Unverified Commit ae8dbe6d authored by nv-dlasalle, committed by GitHub

[Feature][Performance] Implement NCCL wrapper for communicating NodeEmbeddings and sparse gradients. (#2825)

* Split NCCL wrapper from sparse optimizer and sparse embedding

* Add more unit tests for single node nccl

* Fix unit test for tf

* Switch to device histogram

* Fix histogram issues

* Finish migration to histogram

* Handle cases with zero send/receive data

* Start on partition object

* Get compiling

* Updates

* Add unit tests

* Switch to partition object

* Fix linting issues

* Rename partition file

* Add python doc

* Fix python assert and finish doxygen comments

* Remove stubs for range based partition to satisfy pylint

* Wrap unit test in GPU only

* Wrap explicit cuda call in ifdef

* Merge with partition.py

* update docstrings

* Cleanup partition_op

* Add Workspace object

* Switch to using workspace object

* Move last remainder based function out of nccl_api

* Add error messages

* Update docs with examples

* Fix linting errors
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 0e9259b4
@@ -142,6 +142,7 @@ file(GLOB DGL_SRC
   src/geometry/*.cc
   src/geometry/cpu/*.cc
   src/dataloading/*.cc
+  src/partition/*.cc
   )
 file(GLOB_RECURSE DGL_SRC_1
@@ -157,6 +158,10 @@ list(APPEND DGL_SRC ${DGL_SRC_1})
 if(USE_CUDA)
   dgl_config_cuda(DGL_CUDA_SRC)
   list(APPEND DGL_SRC ${DGL_CUDA_SRC})
+
+  include(cmake/util/FindNccl.cmake)
+  include_directories(${NCCL_INCLUDE_DIR})
+  list(APPEND DGL_LINKER_LIBS ${NCCL_LIBRARY})
 endif(USE_CUDA)
 if(USE_CUDA)
...
@@ -235,6 +235,7 @@ macro(dgl_config_cuda out_variable)
   src/array/cuda/*.cu
   src/kernel/cuda/*.cc
   src/kernel/cuda/*.cu
+  src/partition/cuda/*.cu
   src/runtime/cuda/*.cc
   src/runtime/cuda/*.cu
   src/geometry/cuda/*.cu
...
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Tries to find NCCL headers and libraries.
#
# Usage of this module as follows:
#
# find_package(NCCL)
#
# Variables used by this module, they can change the default behaviour and need
# to be set before calling find_package:
#
# NCCL_ROOT - When set, this path is inspected instead of standard library
# locations as the root of the NCCL installation.
# The environment variable NCCL_ROOT overrides this variable.
#
# This module defines
# Nccl_FOUND, whether nccl has been found
# NCCL_INCLUDE_DIR, directory containing header
# NCCL_LIBRARY, directory containing nccl library
# NCCL_LIB_NAME, nccl library name
# USE_NCCL_LIB_PATH, when set, NCCL_LIBRARY path is also inspected for the
# location of the nccl library. This would disable
# switching between static and shared.
#
# This module assumes that the user has already called find_package(CUDA)
#
# This file is from https://github.com/dmlc/xgboost, with modifications to
# check the version.
if (NCCL_LIBRARY)
if(NOT USE_NCCL_LIB_PATH)
# Don't cache NCCL_LIBRARY to enable switching between static and shared.
unset(NCCL_LIBRARY CACHE)
endif(NOT USE_NCCL_LIB_PATH)
endif()
if (BUILD_WITH_SHARED_NCCL)
# libnccl.so
set(NCCL_LIB_NAME nccl)
else ()
# libnccl_static.a
set(NCCL_LIB_NAME nccl_static)
endif (BUILD_WITH_SHARED_NCCL)
find_path(NCCL_INCLUDE_DIR
NAMES nccl.h
PATHS $ENV{NCCL_ROOT}/include ${NCCL_ROOT}/include)
# make sure it has point to point support
file(STRINGS "${NCCL_INCLUDE_DIR}/nccl.h" NCCL_VERSION_CODE REGEX "^#define[ \t]+NCCL_VERSION_CODE[ \t]+[0-9]+.*$" LIMIT_COUNT 1)
string(REGEX REPLACE "^.*NCCL_VERSION_CODE[ \t]+([0-9]+).*$" "\\1" NCCL_VERSION "${NCCL_VERSION_CODE}")
find_library(NCCL_LIBRARY
NAMES ${NCCL_LIB_NAME}
PATHS $ENV{NCCL_ROOT}/lib/ ${NCCL_ROOT}/lib)
if ("${NCCL_VERSION}" LESS "2700")
message(FATAL_ERROR "Require nccl >= 2700, but found ${NCCL_LIBRARY}==${NCCL_VERSION}")
else()
message(STATUS "Using nccl library: ${NCCL_LIBRARY} ${NCCL_VERSION}")
endif()
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Nccl DEFAULT_MSG
NCCL_INCLUDE_DIR NCCL_LIBRARY)
mark_as_advanced(
NCCL_INCLUDE_DIR
NCCL_LIBRARY
)
@@ -20,6 +20,7 @@ from . import random
 from . import sampling
 from . import dataloading
 from . import ops
+from . import cuda
 from ._ffi.runtime_ctypes import TypeCode
 from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
...
""" CUDA wrappers """
from . import nccl
"""API creating NCCL communicators."""
from .. import backend as F
from .._ffi.function import _init_api
_COMM_MODES_MAP = {
'remainder': 0
}
class UniqueId(object):
""" Class for allowing python code to create and communicate NCCL Unique
IDs, needed for creating communicators.
"""
def __init__(self, id_str=None):
""" Create an object reference the current NCCL unique id.
"""
if id_str:
if isinstance(id_str, bytes):
id_str = id_str.decode('utf-8')
self._handle = _CAPI_DGLNCCLUniqueIdFromString(id_str)
else:
self._handle = _CAPI_DGLNCCLGetUniqueId()
def get(self):
""" Get the C-handle for this object.
"""
return self._handle
def __str__(self):
return _CAPI_DGLNCCLUniqueIdToString(self._handle)
def __repr__(self):
return "UniqueId[{}]".format(str(self))
def __eq__(self, other):
return str(self) == str(other)
class Communicator(object):
""" High-level wrapper for NCCL communication.
"""
def __init__(self, size, rank, unique_id):
""" Create a new NCCL communicator.
Parameters
----------
size : int
The number of processes in the communicator.
rank : int
The rank of the current process in the communicator.
unique_id : UniqueId
The unique id of the root process (rank=0).
Examples
--------
>>> from dgl.cuda.nccl import Communicator, UniqueId
The root process will generate a unique NCCL id and communicate it
to the other processes.
>>> uid = UniqueId()
>>> store.set('nccl_root_id', str(uid))
All other processes then recreate the unique id from the string set by the root process.
>>> uid = UniqueId(store.get('nccl_root_id'))
Then, all processes should create the communicator.
>>> comm = Communicator(world_size, rank, uid)
"""
assert rank < size, "The rank of a process must be less than the " \
"size of the communicator."
self._handle = _CAPI_DGLNCCLCreateComm(size, rank, unique_id.get())
self._rank = rank
self._size = size
def sparse_all_to_all_push(self, idx, value, partition):
""" Perform an all-to-all-v operation, where by all processors send out
a set of indices and corresponding values. Indices and values,
corresponding to the current process, will copied into the output
arrays.
Parameters
----------
idx : tensor
The 1D set of indices to send to other processors.
value : tensor
The multi-dimensional set of values to send to other processors.
The 0th dimension must match that of `idx`.
partition : NDArrayPartition
The object containing information for assigning indices to
processors.
Returns
-------
tensor
The 1D tensor of the received indices.
tensor
The set of received values.
Examples
--------
To perform a sparse_all_to_all_push(), a partition object must be
provided. A partition of a homogeneous graph, where the vertices are
striped across processes, can be generated via:
>>> from dgl.partition import NDArrayPartition
>>> part = NDArrayPartition(g.num_nodes(), comm.size(), mode='remainder')
With this partition, each processor can send values to be associated
with vertices in the graph. So if we have an array `global_idxs` of all of
the neighbors updated during mini-batch processing, and an array
`global_values` containing the new values associated with the neighbors,
we communicate them to the owning processes via:
>>> my_idxs, my_values = comm.sparse_all_to_all_push(global_idxs, global_values, part)
This communication pattern is common when communicating gradient
updates for node embeddings.
Indices the current process owns do not need to be treated specially,
as internally they will be copied to the output array. If process 0
has the set of indices '[0, 3, 8, 9, 10]' and process 1 has
'[0, 2, 4, 5, 8, 8, 9]', using a remainder partition will result in
indices '[0, 8, 10, 0, 2, 4, 8, 8]' for process 0, and '[3, 9, 5, 9]'
for process 1.
"""
out_idx, out_value = _CAPI_DGLNCCLSparseAllToAllPush(
self.get(), F.zerocopy_to_dgl_ndarray(idx),
F.zerocopy_to_dgl_ndarray(value),
partition.get())
return (F.zerocopy_from_dgl_ndarray(out_idx),
F.zerocopy_from_dgl_ndarray(out_value))
def sparse_all_to_all_pull(self, req_idx, value, partition):
""" Perform an all-to-all-v operation, where by all processors request
the values corresponding to ther set of indices.
Parameters
----------
req_idx : IdArray
The set of indices this processor is requesting.
value : NDArray
The multi-dimensional set of values that can be requested from
this processor.
partition : NDArrayPartition
The object containing information for assigning indices to
processors.
Returns
-------
tensor
The set of received values, corresponding to `req_idx`.
Examples
--------
To perform a sparse_all_to_all_pull(), a partition object must be
provided. A partition of a homogeneous graph, where the vertices are
striped across processes, can be generated via:
>>> from dgl.partition import NDArrayPartition
>>> part = NDArrayPartition(g.num_nodes(), comm.size(), mode='remainder')
With this partition, each processor can request values/features
associated with vertices in the graph. So in the case where we have
a set of neighbors 'nbr_idxs' we need features for, and each process
has a tensor 'node_feat' storing the features of nodes it owns in
the partition, the features can be requested via:
>>> nbr_values = comm.sparse_all_to_all_pull(nbr_idxs, node_feat, part)
Then the two arrays 'nbr_idxs' and 'nbr_values' form the sparse
set of features, where 'nbr_idxs[i]' is the global node id, and
'nbr_values[i]' is the feature vector for that node. This
communication pattern is useful for node features or node
embeddings.
"""
out_value = _CAPI_DGLNCCLSparseAllToAllPull(
self.get(), F.zerocopy_to_dgl_ndarray(req_idx),
F.zerocopy_to_dgl_ndarray(value),
partition.get())
return F.zerocopy_from_dgl_ndarray(out_value)
def get(self):
""" Get the C-Handle for this object.
"""
return self._handle
def rank(self):
""" Get the rank of this process in this communicator.
Returns
-------
int
The rank of this process.
"""
return self._rank
def size(self):
""" Get the size of this communicator.
Returns
-------
int
The number of processes in this communicator.
"""
return self._size
_init_api("dgl.cuda.nccl")
@@ -377,4 +377,48 @@ def metis_partition(g, k, extra_cached_hops=0, reshuffle=False,
     # Then we split the original graph into parts based on the METIS partitioning results.
     return partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle)[0]
class NDArrayPartition(object):
""" Create a new partition of an NDArray. That is, an object which assigns
each row of an NDArray to a specific partition.
Parameters
----------
array_size : int
The first dimension of the array being partitioned.
num_parts : int
The number of parts to divide the array into.
mode : str
The type of partition. Currently, the only valid value is 'remainder',
which assigns rows based on the remainder when dividing the row id by the
number of parts (e.g., i % num_parts).
part_ranges : List
Currently unused.
Examples
--------
A partition of a homogeneous graph `g`, where the vertices are
striped across processes, can be generated via:
>>> from dgl.partition import NDArrayPartition
>>> part = NDArrayPartition(g.num_nodes(), num_parts, mode='remainder')
"""
def __init__(self, array_size, num_parts, mode='remainder', part_ranges=None):
assert num_parts > 0, 'Invalid "num_parts", must be > 0.'
if mode == 'remainder':
assert part_ranges is None, 'When using remainder-based ' \
'partitioning, "part_ranges" should not be specified.'
self._partition = _CAPI_DGLNDArrayPartitionCreateRemainderBased(
array_size, num_parts)
else:
assert False, 'Unknown partition mode "{}"'.format(mode)
def get(self):
""" Get the C-handle for this object.
"""
return self._partition
_init_api("dgl.partition") _init_api("dgl.partition")
...@@ -114,8 +114,6 @@ NDArray IndexSelect(NDArray array, IdArray index) { ...@@ -114,8 +114,6 @@ NDArray IndexSelect(NDArray array, IdArray index) {
NDArray ret; NDArray ret;
CHECK_SAME_CONTEXT(array, index); CHECK_SAME_CONTEXT(array, index);
CHECK_GE(array->ndim, 1) << "Only support array with at least 1 dimension"; CHECK_GE(array->ndim, 1) << "Only support array with at least 1 dimension";
CHECK_EQ(array->shape[0], array.NumElements()) << "Only support tensor"
<< " whose first dimension equals number of elements, e.g. (5,), (5, 1)";
CHECK_EQ(index->ndim, 1) << "Index array must be an 1D array."; CHECK_EQ(index->ndim, 1) << "Index array must be an 1D array.";
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", { ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "IndexSelect", {
ATEN_DTYPE_SWITCH(array->dtype, DType, "values", { ATEN_DTYPE_SWITCH(array->dtype, DType, "values", {
......
@@ -12,6 +12,9 @@ namespace impl {
 template<DLDeviceType XPU, typename DType, typename IdType>
 NDArray IndexSelect(NDArray array, IdArray index) {
+  CHECK_EQ(array->shape[0], array.NumElements()) << "Only support tensor"
+    << " whose first dimension equals number of elements, e.g. (5,), (5, 1)";
   const DType* array_data = static_cast<DType*>(array->data);
   const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = array->shape[0];
...
@@ -5,6 +5,7 @@
  */
 #include <dgl/array.h>
 #include "../../runtime/cuda/cuda_common.h"
+#include "./array_index_select.cuh"
 #include "./utils.h"
 namespace dgl {
@@ -12,17 +13,6 @@ using runtime::NDArray;
 namespace aten {
 namespace impl {
-template <typename DType, typename IdType>
-__global__ void _IndexSelectKernel(const DType* array, const IdType* index,
-                                   int64_t length, DType* out) {
-  int tx = blockIdx.x * blockDim.x + threadIdx.x;
-  int stride_x = gridDim.x * blockDim.x;
-  while (tx < length) {
-    out[tx] = array[index[tx]];
-    tx += stride_x;
-  }
-}
 template<DLDeviceType XPU, typename DType, typename IdType>
 NDArray IndexSelect(NDArray array, IdArray index) {
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
@@ -30,14 +20,33 @@ NDArray IndexSelect(NDArray array, IdArray index) {
   const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = array->shape[0];
   const int64_t len = index->shape[0];
-  NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
+  int64_t num_feat = 1;
+  std::vector<int64_t> shape{len};
+  for (int d = 1; d < array->ndim; ++d) {
+    num_feat *= array->shape[d];
+    shape.emplace_back(array->shape[d]);
+  }
+  NDArray ret = NDArray::Empty(shape, array->dtype, array->ctx);
   if (len == 0)
     return ret;
   DType* ret_data = static_cast<DType*>(ret->data);
+  if (num_feat == 1) {
     const int nt = cuda::FindNumThreads(len);
     const int nb = (len + nt - 1) / nt;
-    CUDA_KERNEL_CALL(_IndexSelectKernel, nb, nt, 0, thr_entry->stream,
+    CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0, thr_entry->stream,
                      array_data, idx_data, len, ret_data);
+  } else {
+    dim3 block(256, 1);
+    while (static_cast<int64_t>(block.x) >= 2*num_feat) {
+      block.x /= 2;
+      block.y *= 2;
+    }
+    const dim3 grid((len+block.y-1)/block.y);
+    CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, thr_entry->stream,
+                     array_data, num_feat, idx_data, len, ret_data);
+  }
   return ret;
 }
...
/*!
* Copyright (c) 2021 by Contributors
* \file array/cuda/array_index_select.cuh
* \brief Array index select GPU kernel implementation
*/
#ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_
#define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_
namespace dgl {
namespace aten {
namespace impl {
template <typename DType, typename IdType>
__global__ void IndexSelectSingleKernel(const DType* array, const IdType* index,
int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
out[tx] = array[index[tx]];
tx += stride_x;
}
}
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernel(
const DType* const array,
const int64_t num_feat,
const IdType* const index,
const int64_t length,
DType* const out) {
int64_t out_row = blockIdx.x*blockDim.y+threadIdx.y;
const int64_t stride = blockDim.y*gridDim.x;
while (out_row < length) {
int64_t col = threadIdx.x;
const int64_t in_row = index[out_row];
while (col < num_feat) {
out[out_row*num_feat+col] = array[in_row*num_feat+col];
col += blockDim.x;
}
out_row += stride;
}
}
} // namespace impl
} // namespace aten
} // namespace dgl
#endif  // DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_CUH_
/*!
* Copyright (c) 2021 by Contributors
* \file partition/cuda/partition_op.cu
* \brief Operations on partition implemented in CUDA.
*/
#include "../partition_op.h"
#include <dgl/runtime/device_api.h>
#include "../../array/cuda/dgl_cub.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/workspace.h"
using namespace dgl::runtime;
namespace dgl {
namespace partition {
namespace impl {
template<typename IdType>
__global__ void _MapProcByRemainder(
const IdType * const index,
const int64_t num_index,
const int64_t num_proc,
IdType * const proc_id) {
const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
if (idx < num_index) {
proc_id[idx] = index[idx] % num_proc;
}
}
template<typename IdType>
__global__ void _MapProcByMaskRemainder(
const IdType * const index,
const int64_t num_index,
const IdType mask,
IdType * const proc_id) {
const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
if (idx < num_index) {
proc_id[idx] = index[idx] & mask;
}
}
template<typename IdType>
__global__ void _MapLocalIndexByRemainder(
const IdType * const in,
IdType * const out,
const int64_t num_items,
const int comm_size) {
const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
if (idx < num_items) {
out[idx] = in[idx] / comm_size;
}
}
template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray>
GeneratePermutationFromRemainder(
int64_t array_size,
int num_parts,
IdArray in_idx) {
std::pair<IdArray, NDArray> result;
const auto& ctx = in_idx->ctx;
auto device = DeviceAPI::Get(ctx);
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
const int64_t num_in = in_idx->shape[0];
CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts <<
") must be at least 1.";
if (num_parts == 1) {
// no permutation
result.first = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
result.second = aten::Full(num_in, num_parts, sizeof(int64_t)*8, ctx);
return result;
}
result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx);
int64_t * out_counts = static_cast<int64_t*>(result.second->data);
if (num_in == 0) {
// now that we've zero'd out_counts, nothing left to do for an empty
// mapping
return result;
}
const int64_t part_bits =
static_cast<int64_t>(std::ceil(std::log2(num_parts)));
// First, generate a mapping of indexes to processors
Workspace<IdType> proc_id_in(device, ctx, num_in);
{
const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x);
if (num_parts < (1 << part_bits)) {
// num_parts is not a power of 2
CUDA_KERNEL_CALL(_MapProcByRemainder, grid, block, 0, stream,
static_cast<const IdType*>(in_idx->data),
num_in,
num_parts,
proc_id_in.get());
} else {
// num_parts is a power of 2
CUDA_KERNEL_CALL(_MapProcByMaskRemainder, grid, block, 0, stream,
static_cast<const IdType*>(in_idx->data),
num_in,
static_cast<IdType>(num_parts-1), // bit mask
proc_id_in.get());
}
}
// then create a permutation array that groups indices by their assigned
// processor, by performing a radix sort
Workspace<IdType> proc_id_out(device, ctx, num_in);
result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8);
IdType * perm_out = static_cast<IdType*>(result.first->data);
{
IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
size_t sort_workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, sort_workspace_size,
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
num_in, 0, part_bits, stream));
Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(sort_workspace.get(), sort_workspace_size,
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
num_in, 0, part_bits, stream));
}
// explicitly free so workspace can be re-used
proc_id_in.free();
// perform a histogram and then prefixsum on the sorted proc_id vector
// Count the number of values to be sent to each processor
{
using AtomicCount = unsigned long long; // NOLINT
static_assert(sizeof(AtomicCount) == sizeof(*out_counts),
"AtomicCount must be the same width as int64_t for atomicAdd "
"in cub::DeviceHistogram::HistogramEven() to work");
// TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
// add a compile time check against the cub version to allow
// num_in > (2 << 31).
CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max())) <<
"number of values to insert into histogram must be less than max "
"value of int.";
size_t hist_workspace_size;
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr,
hist_workspace_size,
proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts),
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(),
hist_workspace_size,
proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts),
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
}
return result;
}
template std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder<kDLGPU, int32_t>(
int64_t array_size,
int num_parts,
IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder<kDLGPU, int64_t>(
int64_t array_size,
int num_parts,
IdArray in_idx);
template <DLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(
const int num_parts,
IdArray global_idx) {
const auto& ctx = global_idx->ctx;
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
sizeof(IdType)*8);
const dim3 block(128);
const dim3 grid((global_idx->shape[0] +block.x-1)/block.x);
CUDA_KERNEL_CALL(
_MapLocalIndexByRemainder,
grid,
block,
0,
stream,
static_cast<const IdType*>(global_idx->data),
static_cast<IdType*>(local_idx->data),
global_idx->shape[0],
num_parts);
return local_idx;
}
template IdArray
MapToLocalFromRemainder<kDLGPU, int32_t>(
int num_parts,
IdArray in_idx);
template IdArray
MapToLocalFromRemainder<kDLGPU, int64_t>(
int num_parts,
IdArray in_idx);
} // namespace impl
} // namespace partition
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file ndarray_partition.cc
* \brief DGL utilities for working with the partitioned NDArrays
*/
#include "ndarray_partition.h"
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include <utility>
#include <memory>
#include "partition_op.h"
using namespace dgl::runtime;
namespace dgl {
namespace partition {
NDArrayPartition::NDArrayPartition(
const int64_t array_size, const int num_parts) :
array_size_(array_size),
num_parts_(num_parts) {
}
int64_t NDArrayPartition::ArraySize() const {
return array_size_;
}
int NDArrayPartition::NumParts() const {
return num_parts_;
}
class RemainderPartition : public NDArrayPartition {
public:
RemainderPartition(
const int64_t array_size, const int num_parts) :
NDArrayPartition(array_size, num_parts) {
// do nothing
}
std::pair<IdArray, NDArray>
GeneratePermutation(
IdArray in_idx) const override {
auto ctx = in_idx->ctx;
#ifdef DGL_USE_CUDA
if (ctx.device_type == kDLGPU) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::GeneratePermutationFromRemainder<kDLGPU, IdType>(
ArraySize(), NumParts(), in_idx);
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return std::pair<IdArray, NDArray>{};
}
IdArray MapToLocal(
IdArray in_idx) const override {
auto ctx = in_idx->ctx;
#ifdef DGL_USE_CUDA
if (ctx.device_type == kDLGPU) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::MapToLocalFromRemainder<kDLGPU, IdType>(
NumParts(), in_idx);
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return IdArray{};
}
};
NDArrayPartitionRef CreatePartitionRemainderBased(
const int64_t array_size,
const int num_parts) {
return NDArrayPartitionRef(std::make_shared<RemainderPartition>(
array_size, num_parts));
}
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t array_size = args[0];
int num_parts = args[1];
*rv = CreatePartitionRemainderBased(array_size, num_parts);
});
} // namespace partition
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file ndarray_partition.h
* \brief DGL utilities for working with the partitioned NDArrays
*/
#ifndef DGL_PARTITION_NDARRAY_PARTITION_H_
#define DGL_PARTITION_NDARRAY_PARTITION_H_
#include <dgl/runtime/object.h>
#include <dgl/array.h>
#include <utility>
namespace dgl {
namespace partition {
/**
* @brief The top-level partition class. Specific types of partitions should be
* sub-classes of this.
*/
class NDArrayPartition : public runtime::Object {
public:
/**
* @brief Create a new partition.
*
* @param array_size The first dimension of the partitioned array.
* @param num_parts The number of parts the array is split into.
*/
NDArrayPartition(
int64_t array_size,
int num_parts);
virtual ~NDArrayPartition() = default;
static constexpr const char* _type_key = "partition.NDArrayPartition";
DGL_DECLARE_OBJECT_TYPE_INFO(NDArrayPartition, Object);
/**
* @brief Create a mapping for the given indices to different partitions,
* and a count of the number of indices per part.
*
* A prefix-sum of the counts can be used to select the contiguous sets of
* indices destined for each part.
*
* @param in_idx The input indices to map.
*
* @return A pair containing 0) the permutation to re-order the indices by
* partition, 1) the number of indices per partition (int64_t).
*/
virtual std::pair<IdArray, NDArray>
GeneratePermutation(
IdArray in_idx) const = 0;
/**
* @brief Generate the local indices (the numbering within each processor)
* from a set of global indices.
*
* @param in_idx The global indices.
*
* @return The local indices.
*/
virtual IdArray MapToLocal(
IdArray in_idx) const = 0;
/**
* @brief Get the first dimension of the partitioned array.
*
* @return The size.
*/
int64_t ArraySize() const;
/**
* @brief Get the number of parts in this partition.
*
* @return The number of parts.
*/
int NumParts() const;
private:
int64_t array_size_;
int num_parts_;
};
DGL_DEFINE_OBJECT_REF(NDArrayPartitionRef, NDArrayPartition);
/**
* @brief Create a new partition object, using the remainder of the row id
* divided by the number of parts, to assign rows to parts.
*
* @param array_size The first dimension of the array.
* @param num_parts The number of parts.
*
* @return The partition object.
*/
NDArrayPartitionRef CreatePartitionRemainderBased(
int64_t array_size,
int num_parts);
} // namespace partition
} // namespace dgl
#endif // DGL_PARTITION_NDARRAY_PARTITION_H_
/*!
* Copyright (c) 2021 by Contributors
* \file partition_op.h
* \brief DGL utilities for working with the partitioned NDArrays
*/
#ifndef DGL_PARTITION_PARTITION_OP_H_
#define DGL_PARTITION_PARTITION_OP_H_
#include <dgl/array.h>
#include <utility>
namespace dgl {
namespace partition {
namespace impl {
/**
* @brief Create a permutation that groups indices by the part id.
*
* @tparam XPU The type of device to run on.
* @tparam IdType The type of the index.
* @param array_size The total size of the partitioned array.
* @param num_parts The number of parts the array is divided into.
* @param in_idx The array of indices to group by part id.
*
* @return The permutation to group the indices by part id, and the number of
* indices in each part.
*/
template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray>
GeneratePermutationFromRemainder(
int64_t array_size,
int num_parts,
IdArray in_idx);
/**
* @brief Generate the set of local indices from the global indices, using
* remainder. That is, for each index `i` in `global_idx`, the local index
* is computed as `global_idx[i] / num_parts`.
*
* @tparam XPU The type of device to run on.
* @tparam IdType The type of the index.
* @param num_parts The number of parts the array is divided into.
* @param global_idx The array of global indices to map.
*
* @return The array of local indices.
*/
template <DLDeviceType XPU, typename IdType>
IdArray MapToLocalFromRemainder(
int num_parts,
IdArray global_idx);
} // namespace impl
} // namespace partition
} // namespace dgl
#endif // DGL_PARTITION_PARTITION_OP_H_
@@ -190,10 +190,10 @@ class CUDADeviceAPI final : public DeviceAPI {
                           size_t size,
                           cudaMemcpyKind kind,
                           cudaStream_t stream) {
-    if (stream != 0) {
-      CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
-    } else {
-      CUDA_CALL(cudaMemcpy(to, from, size, kind));
+    CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
+    if (stream == 0 && kind == cudaMemcpyDeviceToHost) {
+      // only wait for the copy when it's on the default stream and it's to host memory
+      CUDA_CALL(cudaStreamSynchronize(stream));
     }
   }
 };
...
/*!
* Copyright (c) 2021 by Contributors
* \file nccl_api.cc
* \brief Implementation of wrapper around NCCL routines.
*/
#include "nccl_api.h"
#include <dgl/array.h>
#include <dgl/aten/array_ops.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/device_api.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/registry.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cmath>
#include <sstream>
#include <iomanip>
#include <utility>
#include <vector>
#include <memory>
#include <string>
#include <limits>
#include "cuda_common.h"
#include "../../runtime/workspace.h"
#include "../../partition/ndarray_partition.h"
#include "../../kernel/cuda/atomic.cuh"
#include "../../array/cuda/dgl_cub.cuh"
#include "../../array/cuda/array_index_select.cuh"
#define NCCL_CALL(func) \
{ \
ncclResult_t result = func; \
if (result != ncclSuccess) { \
LOG(FATAL) \
<< "NCCLError: " #func " failed with error: " << result; \
} \
}
namespace dgl {
using namespace kernel::cuda;
using namespace partition;
namespace runtime {
namespace cuda {
namespace {
enum class AllToAllMode : int {
REMAINDER = 0
};
template<typename T> ncclDataType_t NCCLType();
template<> ncclDataType_t NCCLType<int32_t>() {
return ncclInt32;
}
template<> ncclDataType_t NCCLType<int64_t>() {
return ncclInt64;
}
template<> ncclDataType_t NCCLType<__half>() {
return ncclHalf;
}
template<> ncclDataType_t NCCLType<float>() {
return ncclFloat32;
}
template<> ncclDataType_t NCCLType<double>() {
return ncclFloat64;
}
template<typename IdType, typename DType>
__global__ void _DualPermKernel(
const IdType * const in_idx,
const DType * const in_value,
const IdType * const perm,
const int64_t num_in,
const int64_t num_feat,
IdType * const out_idx,
DType * const out_value) {
// set index permutation
const int64_t tidx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
if (tidx < num_in) {
const IdType perm_idx = perm[tidx];
assert(perm_idx < num_in);
out_idx[tidx] = in_idx[perm_idx];
}
if (num_feat > 1) {
for (int d = 0; d < blockDim.x; ++d) {
const int64_t bidx = blockDim.x*static_cast<int64_t>(blockIdx.x) + d;
if (bidx < num_in) {
const IdType perm_idx = perm[bidx];
for (int64_t f = threadIdx.x; f < num_feat; f+=blockDim.x) {
out_value[bidx*num_feat+f] = in_value[perm_idx*num_feat+f];
}
}
}
} else {
if (tidx < num_in) {
const IdType perm_idx = perm[tidx];
out_value[tidx] = in_value[perm_idx];
}
}
}
template <typename DType, typename IdType>
__global__ void _InversePermKernel(
const DType* const array,
const int64_t num_feat,
int64_t length,
const IdType* const perm,
DType* const out) {
int64_t in_row = blockIdx.x*blockDim.y+threadIdx.y;
const int64_t stride = blockDim.y*gridDim.x;
while (in_row < length) {
int64_t col = threadIdx.x;
const int64_t out_row = perm[in_row];
while (col < num_feat) {
out[out_row*num_feat+col] = array[in_row*num_feat+col];
col += blockDim.x;
}
in_row += stride;
}
}
template<typename IdType, typename DType>
std::pair<IdArray, NDArray> SparsePush(
NCCLCommunicatorRef comm,
IdArray in_idx,
NDArray in_value,
NDArrayPartitionRef part) {
CHECK_EQ(in_idx->shape[0], in_value->shape[0]) <<
"Leading dimension of indices (" << in_idx->shape[0] << ") must match "
"leading dimension of values (" << in_value->shape[0] << ").";
const auto& ctx = in_idx->ctx;
CHECK_EQ(ctx, in_value->ctx) << "Indices and values must be on the same "
"device";
auto device = DeviceAPI::Get(ctx);
// TODO(dlasalle): Get the stream from the device context.
cudaStream_t stream = 0;
CHECK_EQ(in_idx->ndim, 1) << "Indices must be 1-dimensional";
const int64_t num_in = in_idx->shape[0];
int64_t num_feat = 1;
for (int d = 1; d < in_value->ndim; ++d) {
num_feat *= in_value->shape[d];
}
const int64_t comm_size = comm->size();
if (comm_size == 1) {
// nothing to do, just return original arrays
return std::pair<IdArray, NDArray>(in_idx, in_value);
}
std::pair<IdArray, NDArray> part_perm = part->GeneratePermutation(in_idx);
const IdType * const perm = static_cast<const IdType*>(part_perm.first->data);
const int64_t * const send_sum =
static_cast<const int64_t*>(part_perm.second->data);
Workspace<IdType> send_idx(device, ctx, num_in);
Workspace<DType> send_value(device, ctx, num_in*num_feat);
// permute the indices and values
{
const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x);
_DualPermKernel<<<grid, block, 0, stream>>>(
static_cast<const IdType*>(in_idx->data),
static_cast<const DType*>(in_value->data),
perm,
num_in,
num_feat,
send_idx.get(),
send_value.get());
CUDA_CALL(cudaGetLastError());
}
// compute the prefix sum of the send values
Workspace<int64_t> send_prefix(device, ctx, comm_size+1);
{
size_t prefix_workspace_size;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_workspace_size,
send_sum, send_prefix.get(), comm_size+1, stream));
Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_workspace.get(),
prefix_workspace_size, send_sum, send_prefix.get(),
comm_size+1, stream));
}
std::vector<int64_t> send_prefix_host(comm_size+1);
device->CopyDataFromTo(
send_prefix.get(),
0,
send_prefix_host.data(),
0,
send_prefix_host.size()*sizeof(*send_prefix.get()),
ctx,
DGLContext{kDLCPU, 0},
DGLType{kDLInt, sizeof(*send_prefix.get())*8, 1},
stream);
send_prefix.free();
CHECK_EQ(send_prefix_host.back(), num_in) << "Internal Error: "
"send_prefix_host.back() = " << send_prefix_host.back() <<
", and num_in = " << num_in;
// communicate the amount to send
Workspace<int64_t> recv_sum(device, ctx, comm_size+1);
comm->AllToAll(send_sum, recv_sum.get(), 1, stream);
cudaEvent_t d2h;
cudaEventCreate(&d2h);
// compute the prefix sum of the recv values
Workspace<int64_t> recv_prefix(device, ctx, comm_size+1);
{
size_t prefix_workspace_size;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_workspace_size,
recv_sum.get(), recv_prefix.get(), comm_size+1));
Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_workspace.get(),
prefix_workspace_size, recv_sum.get(), recv_prefix.get(), comm_size+1));
}
recv_sum.free();
// finally copy the prefixsum sum down to the host
std::vector<int64_t> recv_prefix_host(comm_size+1);
device->CopyDataFromTo(
recv_prefix.get(),
0,
recv_prefix_host.data(),
0,
recv_prefix_host.size()*sizeof(*recv_prefix.get()),
ctx,
DGLContext{kDLCPU, 0},
DGLType{kDLInt, sizeof(*recv_prefix.get())*8, 1},
stream);
recv_prefix.free();
// use an event to track when copying is done
cudaEventRecord(d2h, stream);
// allocate output space
cudaEventSynchronize(d2h);
cudaEventDestroy(d2h);
IdArray recv_idx = aten::NewIdArray(
recv_prefix_host.back(), ctx, sizeof(IdType)*8);
std::vector<int64_t> value_shape(in_value->ndim, 0);
value_shape[0] = recv_prefix_host.back();
for (int d = 1; d < in_value->ndim; ++d) {
value_shape[d] = in_value->shape[d];
}
NDArray recv_value = NDArray::Empty(value_shape, in_value->dtype, ctx);
// send data
comm->SparseAllToAll(
send_idx.get(),
send_value.get(),
num_feat,
send_prefix_host.data(),
static_cast<IdType*>(recv_idx->data),
static_cast<DType*>(recv_value->data),
recv_prefix_host.data(),
stream);
return std::pair<IdArray, NDArray>(recv_idx, recv_value);
}
template<typename IdType, typename DType>
NDArray SparsePull(
NCCLCommunicatorRef comm,
IdArray req_idx,
NDArray local_tensor,
NDArrayPartitionRef part) {
const auto& ctx = req_idx->ctx;
CHECK_EQ(ctx, local_tensor->ctx) << "The request indices and set of local "
"values must be on the same device";
auto device = DeviceAPI::Get(ctx);
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
CHECK_EQ(req_idx->ndim, 1) << "The tensor of requested indices must be of "
"dimension one.";
const int64_t num_in = req_idx->shape[0];
int64_t num_feat = 1;
for (int d = 1; d < local_tensor->ndim; ++d) {
num_feat *= local_tensor->shape[d];
}
const int64_t comm_size = comm->size();
if (comm_size == 1) {
// Just return index selection from current local_tensor
return aten::IndexSelect(local_tensor, req_idx);
}
// First we need to send our requests to other processors. This means
// re-ordering our index array to be contiguous among processors, and
// counting the number of indices we are sending each processor. For now,
// we assume a poorly partitioned graph, and that there exists the
// possibility that each processor could request data from this one.
// the buffer for us to re-order our requests in
Workspace<IdType> send_idx(device, ctx, num_in);
std::pair<IdArray, NDArray> part_perm = part->GeneratePermutation(req_idx);
const IdType * const perm = static_cast<const IdType*>(part_perm.first->data);
const int64_t * const send_sum =
static_cast<const int64_t*>(part_perm.second->data);
// permute requests
{
const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x);
aten::impl::IndexSelectSingleKernel<<<grid, block, 0, stream>>>(
static_cast<const IdType*>(req_idx->data),
perm,
num_in,
send_idx.get());
CUDA_CALL(cudaGetLastError());
}
// compute the prefix sum of the indexes this process is requesting
Workspace<int64_t> request_prefix(device, ctx, comm_size+1);
{
size_t prefix_workspace_size;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_workspace_size,
send_sum, request_prefix.get(), comm_size+1, stream));
Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_workspace.get(),
prefix_workspace_size, send_sum, request_prefix.get(),
comm_size+1, stream));
}
cudaEvent_t d2h;
cudaEventCreate(&d2h);
std::vector<int64_t> request_prefix_host(comm_size+1);
device->CopyDataFromTo(
request_prefix.get(),
0,
request_prefix_host.data(),
0,
request_prefix_host.size()*sizeof(*request_prefix.get()),
ctx,
DGLContext{kDLCPU, 0},
DGLType{kDLInt, sizeof(*request_prefix.get())*8, 1},
stream);
request_prefix.free();
CHECK_EQ(request_prefix_host.back(), num_in) << "Internal Error: "
"request_prefix_host.back() = " << request_prefix_host.back() <<
", num_in = " << num_in;
// communicate the amount requested
Workspace<int64_t> recv_sum(device, ctx, comm_size+1);
comm->AllToAll(send_sum, recv_sum.get(), 1, stream);
// compute the prefix sum of the requested indexes
Workspace<int64_t> response_prefix(device, ctx, comm_size+1);
{
size_t prefix_workspace_size;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_workspace_size,
recv_sum.get(), response_prefix.get(), comm_size+1, stream));
Workspace<void> prefix_workspace(device, ctx, prefix_workspace_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_workspace.get(),
prefix_workspace_size, recv_sum.get(), response_prefix.get(),
comm_size+1, stream));
}
recv_sum.free();
// finally copy the prefixsum sum down to the host
std::vector<int64_t> response_prefix_host(comm_size+1);
device->CopyDataFromTo(
response_prefix.get(),
0,
response_prefix_host.data(),
0,
response_prefix_host.size()*sizeof(*response_prefix.get()),
ctx,
DGLContext{kDLCPU, 0},
DGLType{kDLInt, sizeof(*response_prefix.get())*8, 1},
stream);
response_prefix.free();
// use an event to track when copying is done
cudaEventRecord(d2h, stream);
// allocate output space
cudaEventSynchronize(d2h);
cudaEventDestroy(d2h);
// gather requested indexes
IdArray recv_idx = aten::NewIdArray(
response_prefix_host.back(), ctx, sizeof(IdType)*8);
comm->AllToAllV(
send_idx.get(),
request_prefix_host.data(),
static_cast<IdType*>(recv_idx->data),
response_prefix_host.data(),
stream);
send_idx.free();
// convert requested indices to local indices depending on partition
if (response_prefix_host.back() > 0) {
recv_idx = part->MapToLocal(recv_idx);
}
// and then index select them into place
Workspace<DType> filled_response_value(device, ctx,
response_prefix_host.back()*num_feat);
if (request_prefix_host.back() > 0) {
dim3 block(256, 1);
while (block.x >= 2*num_feat) {
block.x /= 2;
block.y *= 2;
}
const dim3 grid((request_prefix_host.back()+block.y-1)/block.y);
aten::impl::IndexSelectMultiKernel<<<grid, block, 0, stream>>>(
static_cast<const DType*>(local_tensor->data),
num_feat,
static_cast<IdType*>(recv_idx->data),
response_prefix_host.back(),
filled_response_value.get());
CUDA_CALL(cudaGetLastError());
}
// we will collect received values in this array
std::vector<int64_t> value_shape(local_tensor->ndim, 0);
value_shape[0] = request_prefix_host.back();
for (int d = 1; d < local_tensor->ndim; ++d) {
value_shape[d] = local_tensor->shape[d];
}
Workspace<DType> filled_request_value(device, ctx,
request_prefix_host.back()*num_feat);
// multiply the prefixes by the number of features being sent
for (auto& v : request_prefix_host) {
v *= num_feat;
}
for (auto& v : response_prefix_host) {
v *= num_feat;
}
// send the values
comm->AllToAllV(
filled_response_value.get(),
response_prefix_host.data(),
filled_request_value.get(),
request_prefix_host.data(),
stream);
filled_response_value.free();
// finally, we need to permute the values back into the requested order
NDArray result = NDArray::Empty(value_shape, local_tensor->dtype, ctx);
if (num_in > 0) {
dim3 block(256, 1);
while (block.x >= 2*num_feat) {
block.x /= 2;
block.y *= 2;
}
const dim3 grid((num_in+block.y-1)/block.y);
_InversePermKernel<<<grid, block, 0, stream>>>(
filled_request_value.get(),
num_feat,
num_in,
perm,
static_cast<DType*>(result->data));
CUDA_CALL(cudaGetLastError());
}
return result;
}
} // namespace
/* NCCLUniqueId **************************************************************/
NCCLUniqueId::NCCLUniqueId() :
id_() {
// this ID is unique to the process, not to each call of this function
NCCL_CALL(ncclGetUniqueId(&id_));
}
ncclUniqueId NCCLUniqueId::Get() const {
return id_;
}
std::string NCCLUniqueId::ToString() const {
std::ostringstream oss;
oss << std::hex;
for (size_t b = 0; b < NCCL_UNIQUE_ID_BYTES; ++b) {
const int num = static_cast<uint8_t>(id_.internal[b]);
oss << std::setw(2) << std::setfill('0') << num;
}
std::string result = oss.str();
CHECK_EQ(result.length(), NCCL_UNIQUE_ID_BYTES*2) <<
"Invalid NCCL ID format: '" << result << "'";
return result;
}
void NCCLUniqueId::FromString(
const std::string& str) {
// must be exactly 256 hex characters
CHECK_EQ(str.length(), NCCL_UNIQUE_ID_BYTES * 2) <<
"Invalid NCCL ID format: '" << str << "'";
for (size_t b = 0; b < NCCL_UNIQUE_ID_BYTES; ++b) {
id_.internal[b] = std::strtol(str.substr(b*2, 2).c_str(), nullptr, 16);
}
}
/* NCCLCommunicator **********************************************************/
NCCLCommunicator::NCCLCommunicator(
const int size,
const int rank,
ncclUniqueId id) :
comm_(),
size_(size),
rank_(rank) {
CHECK_LT(rank, size) << "The rank (" << rank << ") must be smaller than "
"the size of the communicator (" << size << ").";
CHECK_GE(rank, 0) << "The rank (" << rank << ") must be greater than or "
"equal to 0.";
NCCL_CALL(ncclCommInitRank(&comm_, size_, id, rank_));
}
NCCLCommunicator::~NCCLCommunicator() {
ncclCommDestroy(comm_);
}
ncclComm_t NCCLCommunicator::Get() {
return comm_;
}
template<typename DType>
void NCCLCommunicator::AllToAllV(
const DType * const send,
const int64_t * const send_prefix,
DType * const recv,
const int64_t * const recv_prefix,
cudaStream_t stream) {
const ncclDataType_t type = NCCLType<DType>();
NCCL_CALL(ncclGroupStart());
for (int r = 0; r < size_; ++r) {
const int64_t send_size = send_prefix[r+1]-send_prefix[r];
if (send_size > 0) {
NCCL_CALL(ncclSend(send+send_prefix[r], send_size, type, r, comm_, stream));
}
const int64_t recv_size = recv_prefix[r+1]-recv_prefix[r];
if (recv_size > 0) {
NCCL_CALL(ncclRecv(recv+recv_prefix[r], recv_size, type, r, comm_, stream));
}
}
NCCL_CALL(ncclGroupEnd());
}
template
void NCCLCommunicator::AllToAllV<int32_t>(
const int32_t * const send,
const int64_t * send_prefix,
int32_t * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
template
void NCCLCommunicator::AllToAllV<int64_t>(
const int64_t * const send,
const int64_t * send_prefix,
int64_t * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
template
void NCCLCommunicator::AllToAllV<float>(
const float * const send,
const int64_t * send_prefix,
float * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
template
void NCCLCommunicator::AllToAllV<__half>(
const __half * const send,
const int64_t * send_prefix,
__half * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
template<typename IdType>
void NCCLCommunicator::AllToAll(
const IdType * const send,
IdType * const recv,
const int64_t count,
cudaStream_t stream) {
const ncclDataType_t type = NCCLType<IdType>();
ncclGroupStart();
for (int r = 0; r < size_; ++r) {
ncclSend(send+(r*count), count, type, r, comm_, stream);
ncclRecv(recv+(r*count), count, type, r, comm_, stream);
}
ncclGroupEnd();
}
template
void NCCLCommunicator::AllToAll<int32_t>(
const int32_t * const send,
int32_t * const recv,
const int64_t count,
cudaStream_t stream);
template
void NCCLCommunicator::AllToAll<int64_t>(
const int64_t * const send,
int64_t * const recv,
const int64_t count,
cudaStream_t stream);
template<typename IdType, typename DType>
void NCCLCommunicator::SparseAllToAll(
const IdType * const send_idx,
const DType * const send_value,
const int64_t num_feat,
const int64_t * const send_prefix,
IdType * const recv_idx,
DType * const recv_value,
const int64_t * const recv_prefix,
cudaStream_t stream) {
const ncclDataType_t idx_type = NCCLType<IdType>();
const ncclDataType_t value_type = NCCLType<DType>();
ncclGroupStart();
for (int r = 0; r < size_; ++r) {
const int64_t send_size = send_prefix[r+1]-send_prefix[r];
if (send_size > 0) {
ncclSend(send_idx+send_prefix[r], send_size, idx_type, r, comm_, stream);
ncclSend(send_value+send_prefix[r]*num_feat, send_size*num_feat,
value_type, r, comm_, stream);
}
const int64_t recv_size = recv_prefix[r+1]-recv_prefix[r];
if (recv_size > 0) {
ncclRecv(recv_idx+recv_prefix[r], recv_size, idx_type, r, comm_, stream);
ncclRecv(recv_value+recv_prefix[r]*num_feat, recv_size*num_feat,
value_type, r, comm_, stream);
}
}
ncclGroupEnd();
}
template
void NCCLCommunicator::SparseAllToAll<int32_t, __half>(
const int32_t * const send_idx,
const __half * const send_value,
const int64_t num_feat,
const int64_t * const send_prefix,
int32_t * const recv_idx,
__half * const recv_value,
const int64_t * const recv_prefix,
cudaStream_t stream);
template
void NCCLCommunicator::SparseAllToAll<int64_t, __half>(
const int64_t * const send_idx,
const __half * const send_value,
const int64_t num_feat,
const int64_t * const send_prefix,
int64_t * const recv_idx,
__half * const recv_value,
const int64_t * const recv_prefix,
cudaStream_t stream);
int NCCLCommunicator::size() const {
return size_;
}
int NCCLCommunicator::rank() const {
return rank_;
}
/* CAPI **********************************************************************/
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLGetUniqueId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = NCCLUniqueIdRef(std::make_shared<NCCLUniqueId>());
});
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLUniqueIdToString")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NCCLUniqueIdRef idObj = args[0];
*rv = idObj->ToString();
});
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLUniqueIdFromString")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const std::string str = args[0];
NCCLUniqueIdRef ref(std::make_shared<NCCLUniqueId>());
ref->FromString(str);
*rv = ref;
});
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLCreateComm")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int size = args[0];
const int rank = args[1];
NCCLUniqueIdRef idObj = args[2];
*rv = NCCLCommunicatorRef(std::make_shared<NCCLCommunicator>(size, rank,
idObj->Get()));
});
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLSparseAllToAllPush")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NCCLCommunicatorRef comm = args[0];
IdArray in_idx = args[1];
NDArray in_values = args[2];
NDArrayPartitionRef part = args[3];
List<ObjectRef> ret;
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_DTYPE_SWITCH(in_values->dtype, DType, "values", {
auto result = SparsePush<IdType, DType>(comm, in_idx, in_values, part);
ret.push_back(Value(MakeValue(result.first)));
ret.push_back(Value(MakeValue(result.second)));
});
});
*rv = ret;
});
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLSparseAllToAllPull")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NCCLCommunicatorRef comm = args[0];
// the indexes this process is requesting from others
IdArray req_idx = args[1];
// the tensor this process has to fulfill other requests
NDArray tensor = args[2];
NDArrayPartitionRef part = args[3];
ATEN_ID_TYPE_SWITCH(req_idx->dtype, IdType, {
ATEN_DTYPE_SWITCH(tensor->dtype, DType, "values", {
*rv = SparsePull<IdType, DType>(comm, req_idx, tensor, part);
});
});
});
} // namespace cuda
} // namespace runtime
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file nccl_api.h
* \brief Wrapper around NCCL routines.
*/
#ifndef DGL_RUNTIME_CUDA_NCCL_API_H_
#define DGL_RUNTIME_CUDA_NCCL_API_H_
#include "nccl.h"
#include <dgl/runtime/object.h>
#include <string>
namespace dgl {
namespace runtime {
namespace cuda {
class NCCLUniqueId : public runtime::Object {
public:
NCCLUniqueId();
static constexpr const char* _type_key = "cuda.NCCLUniqueId";
DGL_DECLARE_OBJECT_TYPE_INFO(NCCLUniqueId, Object);
ncclUniqueId Get() const;
std::string ToString() const;
void FromString(const std::string& str);
private:
ncclUniqueId id_;
};
DGL_DEFINE_OBJECT_REF(NCCLUniqueIdRef, NCCLUniqueId);
class NCCLCommunicator : public runtime::Object {
public:
NCCLCommunicator(
int size,
int rank,
ncclUniqueId id);
~NCCLCommunicator();
// disable copying
NCCLCommunicator(const NCCLCommunicator& other) = delete;
NCCLCommunicator& operator=(
const NCCLCommunicator& other) = delete;
ncclComm_t Get();
/**
* @brief Perform an all-to-all communication.
*
* @param send The contiguous array of data to send.
* @param recv The contiguous array of data to receive.
* @param count The size of data to send to each rank.
* @param stream The stream to operate on.
*/
template<typename IdType>
void AllToAll(
const IdType * send,
IdType * recv,
int64_t count,
cudaStream_t stream);
/**
* @brief Perform an all-to-all variable sized communication.
*
* @tparam DType The type of value to send.
* @param send The arrays of data to send.
* @param send_prefix The prefix of each array to send.
* @param recv The arrays of data to receive.
* @param recv_prefix The prefix of each array to receive.
* @param stream The stream to operate on.
*/
template<typename DType>
void AllToAllV(
const DType * const send,
const int64_t * send_prefix,
DType * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
/**
* @brief Perform an all-to-all with sparse data (idx and value pairs). By
* necessity, the sizes of each message are variable.
*
* @tparam IdType The type of index.
* @tparam DType The type of value.
* @param send_idx The set of indexes to send on the device.
* @param send_value The set of values to send on the device.
* @param num_feat The number of values per index.
* @param send_prefix The exclusive prefix sum of elements to send on the
* host.
* @param recv_idx The set of indexes to receive on the device.
* @param recv_value The set of values to receive on the device.
* @param recv_prefix The exclusive prefix sum of the number of elements to
* receive on the host.
* @param stream The stream to communicate on.
*/
template<typename IdType, typename DType>
void SparseAllToAll(
const IdType * send_idx,
const DType * send_value,
const int64_t num_feat,
const int64_t * send_prefix,
IdType * recv_idx,
DType * recv_value,
const int64_t * recv_prefix,
cudaStream_t stream);
int size() const;
int rank() const;
static constexpr const char* _type_key = "cuda.NCCLCommunicator";
DGL_DECLARE_OBJECT_TYPE_INFO(NCCLCommunicator, Object);
private:
ncclComm_t comm_;
int size_;
int rank_;
};
DGL_DEFINE_OBJECT_REF(NCCLCommunicatorRef, NCCLCommunicator);
} // namespace cuda
} // namespace runtime
} // namespace dgl
#endif // DGL_RUNTIME_CUDA_NCCL_API_H_
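As a quick illustration of the host-side exclusive prefix sums that AllToAllV and SparseAllToAll expect, here is a sketch in Python; the per-rank counts are made up and the helper is not part of any DGL API.
# Exclusive prefix sum over per-rank send counts; send_prefix[r]:send_prefix[r+1]
# is the slice of the send buffer destined for rank r.
def exclusive_prefix_sum(counts):
    prefix = [0]
    for c in counts:
        prefix.append(prefix[-1] + c)
    return prefix

send_counts = [3, 0, 5, 2]   # hypothetical counts for a 4-rank communicator
send_prefix = exclusive_prefix_sum(send_counts)
# send_prefix == [0, 3, 3, 8, 10]; ranks whose slice is empty are skipped.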
/*!
* Copyright (c) 2021 by Contributors
* \file runtime/workspace.h
* \brief RAII wrapper for temporary device workspace memory.
*/
#ifndef DGL_RUNTIME_WORKSPACE_H_
#define DGL_RUNTIME_WORKSPACE_H_
#include <dgl/runtime/device_api.h>
#include <cassert>
namespace dgl {
namespace runtime {
template<typename T>
class Workspace {
public:
Workspace(DeviceAPI* device, DGLContext ctx, const size_t size) :
device_(device),
ctx_(ctx),
ptr_(static_cast<T*>(device_->AllocWorkspace(ctx_, sizeof(T)*size))) {
}
~Workspace() {
if (*this) {
free();
}
}
operator bool() const {
return ptr_ != nullptr;
}
T * get() {
assert(*this);
return ptr_;
}
T const * get() const {
assert(*this);
return ptr_;
}
void free() {
assert(*this);
device_->FreeWorkspace(ctx_, ptr_);
ptr_ = nullptr;
}
private:
DeviceAPI* device_;
DGLContext ctx_;
T * ptr_;
};
template<>
class Workspace<void> {
public:
Workspace(DeviceAPI* device, DGLContext ctx, const size_t size) :
device_(device),
ctx_(ctx),
ptr_(static_cast<void*>(device_->AllocWorkspace(ctx_, size))) {
}
~Workspace() {
if (*this) {
free();
}
}
operator bool() const {
return ptr_ != nullptr;
}
void * get() {
assert(*this);
return ptr_;
}
void const * get() const {
assert(*this);
return ptr_;
}
void free() {
assert(*this);
device_->FreeWorkspace(ctx_, ptr_);
ptr_ = nullptr;
}
private:
DeviceAPI* device_;
DGLContext ctx_;
void * ptr_;
};
} // namespace runtime
} // namespace dgl
#endif // DGL_RUNTIME_WORKSPACE_H_
from dgl.cuda import nccl
from dgl.partition import NDArrayPartition
import unittest
import backend as F
def gen_test_id():
return '{:0256x}'.format(78236728318467363)
@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_id():
nccl_id = nccl.UniqueId()
text = str(nccl_id)
nccl_id2 = nccl.UniqueId(id_str=text)
assert nccl_id == nccl_id2
nccl_id2 = nccl.UniqueId(gen_test_id())
assert nccl_id2 != nccl_id
nccl_id3 = nccl.UniqueId(str(nccl_id2))
assert nccl_id2 == nccl_id3
@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_sparse_push_single():
nccl_id = nccl.UniqueId()
comm = nccl.Communicator(1, 0, nccl_id)
index = F.randint([10000], F.int32, F.ctx(), 0, 10000)
value = F.uniform([10000, 100], F.float32, F.ctx(), -1.0, 1.0)
part = NDArrayPartition(10000, 1, 'remainder')
ri, rv = comm.sparse_all_to_all_push(index, value, part)
assert F.array_equal(ri, index)
assert F.array_equal(rv, value)
@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_sparse_pull_single():
nccl_id = nccl.UniqueId()
comm = nccl.Communicator(1, 0, nccl_id)
req_index = F.randint([10000], F.int64, F.ctx(), 0, 100000)
value = F.uniform([100000, 100], F.float32, F.ctx(), -1.0, 1.0)
part = NDArrayPartition(100000, 1, 'remainder')
rv = comm.sparse_all_to_all_pull(req_index, value, part)
exp_rv = F.gather_row(value, req_index)
assert F.array_equal(rv, exp_rv)
if __name__ == '__main__':
test_nccl_id()
test_nccl_sparse_push_single()
test_nccl_sparse_pull_single()