Unverified commit 1dddaad4 authored by nv-dlasalle, committed by GitHub

[bugfix] Allow communicators of size one when NCCL is missing (#3713)



* Update nccl communicator for when NCCL is missing

* Use static_cast

* Add doc string

* Fix whitespace

* Restrict unit test to GPU runs
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent a25a14f2
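For context, a minimal sketch of what this change enables from Python. The UniqueId/Communicator call pattern below is assumed from the dgl.cuda.nccl wrapper touched in this diff and is illustrative only:

    from dgl.cuda import nccl

    # A single-rank communicator now works even when DGL was built without
    # NCCL, so single-GPU code paths can keep constructing one.
    uid = nccl.UniqueId()
    comm = nccl.Communicator(1, 0, uid)  # size=1, rank=0

    # Query at runtime whether NCCL support was compiled in before
    # requesting more than one rank.
    if not nccl.is_supported():
        print('DGL built without NCCL: only size-1 communicators are available.')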
...@@ -210,4 +210,14 @@ class Communicator(object):
""" """
return self._size return self._size
def is_supported():
    """ Check if DGL was built with NCCL support.

    Returns
    -------
    bool
        True if NCCL support was built in.
    """
    return _CAPI_DGLNCCLHasSupport()
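A short, illustrative caller-side check built on this helper (world_size is a hypothetical variable, not part of the diff):

    if world_size > 1 and not is_supported():
        raise RuntimeError('Multi-GPU communication requires DGL built with NCCL.')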
_init_api("dgl.cuda.nccl")
/*!
* Copyright (c) 2021-2022 by Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...@@ -18,8 +18,6 @@
*/
-#ifdef DGL_USE_NCCL
#include "nccl_api.h" #include "nccl_api.h"
#include <dgl/array.h> #include <dgl/array.h>
...@@ -38,6 +36,7 @@ ...@@ -38,6 +36,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <algorithm>
#include <limits> #include <limits>
#include "cuda_common.h" #include "cuda_common.h"
...@@ -64,10 +63,7 @@ namespace cuda {
namespace {
-enum class AllToAllMode : int {
-  REMAINDER = 0
-};
#ifdef DGL_USE_NCCL
template<typename T> ncclDataType_t NCCLType();
template<> ncclDataType_t NCCLType<int32_t>() {
...@@ -86,6 +82,7 @@ template<> ncclDataType_t NCCLType<double>() {
return ncclFloat64;
}
#endif // DGL_USE_NCCL
template<typename IdType, typename DType>
__global__ void _DualPermKernel(
...@@ -517,8 +514,14 @@ NDArray SparsePull(
NCCLUniqueId::NCCLUniqueId() :
  id_() {
#ifdef DGL_USE_NCCL
// this ID is unique to the process, not to each call of this function
NCCL_CALL(ncclGetUniqueId(&id_));
#else
// when NCCL isn't enabled, use all zeros
std::fill(id_.internal, id_.internal + NCCL_UNIQUE_ID_BYTES,
static_cast<char>(0));
#endif
}
ncclUniqueId NCCLUniqueId::Get() const {
...@@ -554,7 +557,6 @@ void NCCLUniqueId::FromString(
}
/* NCCLCommunicator **********************************************************/
NCCLCommunicator::NCCLCommunicator(
...@@ -569,11 +571,19 @@ NCCLCommunicator::NCCLCommunicator(
CHECK_GE(rank, 0) << "The rank (" << rank << ") must be greater than or "
    "equal to 0.";
#ifdef DGL_USE_NCCL
NCCL_CALL(ncclCommInitRank(&comm_, size_, id, rank_));
#else
CHECK_EQ(size, 1) << "Cannot create a communicator of size " << size << ". "
"To use a communicator size greater than 1, compile DGL with NCCL "
"support.";
#endif
}
NCCLCommunicator::~NCCLCommunicator() {
#ifdef DGL_USE_NCCL
ncclCommDestroy(comm_);
#endif
}
ncclComm_t NCCLCommunicator::Get() {
...@@ -587,6 +597,7 @@ void NCCLCommunicator::AllToAllV(
DType * const recv,
const int64_t * const recv_prefix,
cudaStream_t stream) {
#ifdef DGL_USE_NCCL
const ncclDataType_t type = NCCLType<DType>();
NCCL_CALL(ncclGroupStart());
...@@ -601,6 +612,24 @@
}
}
NCCL_CALL(ncclGroupEnd());
#else
// without NCCL, only a size-1 communicator exists, so the all-to-all
// reduces to copying the single local range from send to recv
CHECK_EQ(send_prefix[1]-send_prefix[0], recv_prefix[1]-recv_prefix[0]) <<
    "Send message size must equal receive message size.";
int dev_id;
CUDA_CALL(cudaGetDevice(&dev_id));
DGLContext ctx{kDLGPU, dev_id};
auto device = runtime::DeviceAPI::Get(ctx);
auto dtype = DLDataTypeTraits<DType>::dtype;
device->CopyDataFromTo(send, send_prefix[0],
    recv, recv_prefix[0],
    sizeof(DType)*(send_prefix[1]-send_prefix[0]),
    ctx, ctx,
    dtype,
    stream);
#endif
}
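An illustrative picture of the non-NCCL fallback above, with made-up numbers: a size-1 communicator has two-entry prefix arrays, and the "exchange" reduces to a local copy.

    send_prefix = [0, 5]   # rank 0 keeps all 5 elements
    recv_prefix = [0, 5]   # receive side must be the same size
    assert send_prefix[1] - send_prefix[0] == recv_prefix[1] - recv_prefix[0]
    send = [10, 11, 12, 13, 14]
    recv = list(send)      # all-to-all with one rank is just a copy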
template
...@@ -639,6 +668,7 @@ void NCCLCommunicator::AllToAll(
IdType * const recv,
const int64_t count,
cudaStream_t stream) {
#ifdef DGL_USE_NCCL
const ncclDataType_t type = NCCLType<IdType>();
NCCL_CALL(ncclGroupStart());
...@@ -647,6 +677,16 @@
NCCL_CALL(ncclRecv(recv+(r*count), count, type, r, comm_, stream));
}
NCCL_CALL(ncclGroupEnd());
#else
// without NCCL the communicator has size 1, so just copy locally
int dev_id;
CUDA_CALL(cudaGetDevice(&dev_id));
DGLContext ctx{kDLGPU, dev_id};
auto device = runtime::DeviceAPI::Get(ctx);
auto dtype = DLDataTypeTraits<IdType>::dtype;
device->CopyDataFromTo(send, 0, recv, 0, sizeof(IdType)*count, ctx, ctx, dtype, stream);
#endif
}
template
...@@ -673,27 +713,20 @@ void NCCLCommunicator::SparseAllToAll(
DType * const recv_value,
const int64_t * const recv_prefix,
cudaStream_t stream) {
-const ncclDataType_t idx_type = NCCLType<IdType>();
-const ncclDataType_t value_type = NCCLType<DType>();
// idxs
AllToAllV(send_idx, send_prefix, recv_idx, recv_prefix, stream);
-// values
-NCCL_CALL(ncclGroupStart());
-for (int r = 0; r < size_; ++r) {
-  const int64_t send_size = send_prefix[r+1]-send_prefix[r];
-  if (send_size > 0) {
-    NCCL_CALL(ncclSend(send_value+send_prefix[r]*num_feat, send_size*num_feat,
-          value_type, r, comm_, stream));
-  }
-  const int64_t recv_size = recv_prefix[r+1]-recv_prefix[r];
-  if (recv_size > 0) {
-    NCCL_CALL(ncclRecv(recv_value+recv_prefix[r]*num_feat, recv_size*num_feat,
-          value_type, r, comm_, stream));
-  }
-}
-NCCL_CALL(ncclGroupEnd());
// scale prefixes by number of features
std::vector<int64_t> value_send_prefix(size_+1);
for (int r = 0; r < size_+1; ++r) {
  value_send_prefix[r] = send_prefix[r]*num_feat;
}
std::vector<int64_t> value_recv_prefix(size_+1);
for (int r = 0; r < size_+1; ++r) {
  value_recv_prefix[r] = recv_prefix[r]*num_feat;
}
AllToAllV(send_value, value_send_prefix.data(),
    recv_value, value_recv_prefix.data(), stream);
}
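A small worked example of the prefix scaling above (made-up sizes): the value prefixes are the index prefixes multiplied by the number of features per row, so the index and value exchanges can share the same AllToAllV path.

    num_feat = 4
    send_prefix = [0, 2, 5]                    # index counts, 2 ranks
    value_send_prefix = [p * num_feat for p in send_prefix]
    assert value_send_prefix == [0, 8, 20]     # feature elements per rank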
...@@ -795,10 +828,15 @@ DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLSparseAllToAllPull")
});
});
DGL_REGISTER_GLOBAL("cuda.nccl._CAPI_DGLNCCLHasSupport")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
#ifndef DGL_USE_NCCL
*rv = false;
#else
*rv = true;
#endif
});
} // namespace cuda
} // namespace runtime
} // namespace dgl
-#endif
/*!
* Copyright (c) 2021-2022 by Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file nccl_api.h
* \brief Wrapper around NCCL routines.
*/
...@@ -8,7 +21,15 @@
#ifndef DGL_RUNTIME_CUDA_NCCL_API_H_
#define DGL_RUNTIME_CUDA_NCCL_API_H_
#ifdef DGL_USE_NCCL
#include "nccl.h"
#else
// if not compiling with NCCL, this class will only support communicators of
// size 1.
#define NCCL_UNIQUE_ID_BYTES 128
typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId;
typedef int ncclComm_t;
#endif
#include <dgl/runtime/object.h>
#include <string>
...
...@@ -82,6 +82,11 @@ def test_nccl_sparse_pull_single_range():
    exp_rv = F.gather_row(value, req_index)
    assert F.array_equal(rv, exp_rv)
@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_support():
    # this is just a smoke test, as we don't have any other way to know
    # if NCCL support is compiled in right now.
    nccl.is_supported()
if __name__ == '__main__':
    test_nccl_id()
...