Unverified commit c560040f, authored by nv-dlasalle and committed by GitHub
Browse files

[Fix] Split nccl sparse push into two groups (#3404)

parent aa11aaa4
/*! /*!
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* \file nccl_api.cc *
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file nccl_api.cu
* \brief Implementation of wrapper around NCCL routines. * \brief Implementation of wrapper around NCCL routines.
*/ */
#ifdef DGL_USE_NCCL #ifdef DGL_USE_NCCL
#include "nccl_api.h" #include "nccl_api.h"
...@@ -627,12 +641,12 @@ void NCCLCommunicator::AllToAll( ...@@ -627,12 +641,12 @@ void NCCLCommunicator::AllToAll(
cudaStream_t stream) { cudaStream_t stream) {
const ncclDataType_t type = NCCLType<IdType>(); const ncclDataType_t type = NCCLType<IdType>();
ncclGroupStart(); NCCL_CALL(ncclGroupStart());
for (int r = 0; r < size_; ++r) { for (int r = 0; r < size_; ++r) {
ncclSend(send+(r*count), count, type, r, comm_, stream); NCCL_CALL(ncclSend(send+(r*count), count, type, r, comm_, stream));
ncclRecv(recv+(r*count), count, type, r, comm_, stream); NCCL_CALL(ncclRecv(recv+(r*count), count, type, r, comm_, stream));
} }
ncclGroupEnd(); NCCL_CALL(ncclGroupEnd());
} }
template template
...@@ -662,24 +676,27 @@ void NCCLCommunicator::SparseAllToAll( ...@@ -662,24 +676,27 @@ void NCCLCommunicator::SparseAllToAll(
const ncclDataType_t idx_type = NCCLType<IdType>(); const ncclDataType_t idx_type = NCCLType<IdType>();
const ncclDataType_t value_type = NCCLType<DType>(); const ncclDataType_t value_type = NCCLType<DType>();
ncclGroupStart(); // idxs
AllToAllV(send_idx, send_prefix, recv_idx, recv_prefix, stream);
// values
NCCL_CALL(ncclGroupStart());
for (int r = 0; r < size_; ++r) { for (int r = 0; r < size_; ++r) {
const int64_t send_size = send_prefix[r+1]-send_prefix[r]; const int64_t send_size = send_prefix[r+1]-send_prefix[r];
if (send_size > 0) { if (send_size > 0) {
ncclSend(send_idx+send_prefix[r], send_size, idx_type, r, comm_, stream); NCCL_CALL(ncclSend(send_value+send_prefix[r]*num_feat, send_size*num_feat,
ncclSend(send_value+send_prefix[r]*num_feat, send_size*num_feat, value_type, r, comm_, stream));
value_type, r, comm_, stream);
} }
const int64_t recv_size = recv_prefix[r+1]-recv_prefix[r]; const int64_t recv_size = recv_prefix[r+1]-recv_prefix[r];
if (recv_size > 0) { if (recv_size > 0) {
ncclRecv(recv_idx+recv_prefix[r], recv_size, idx_type, r, comm_, stream); NCCL_CALL(ncclRecv(recv_value+recv_prefix[r]*num_feat, recv_size*num_feat,
ncclRecv(recv_value+recv_prefix[r]*num_feat, recv_size*num_feat, value_type, r, comm_, stream));
value_type, r, comm_, stream);
} }
} }
ncclGroupEnd(); NCCL_CALL(ncclGroupEnd());
} }
template template
void NCCLCommunicator::SparseAllToAll<int32_t, __half>( void NCCLCommunicator::SparseAllToAll<int32_t, __half>(
const int32_t * const send_idx, const int32_t * const send_idx,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment