Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c560040f
Unverified
Commit
c560040f
authored
Oct 18, 2021
by
nv-dlasalle
Committed by
GitHub
Oct 18, 2021
Browse files
[Fix] Split nccl sparse push into two groups (#3404)
parent
aa11aaa4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
13 deletions
+30
-13
src/runtime/cuda/nccl_api.cu
src/runtime/cuda/nccl_api.cu
+30
-13
No files found.
src/runtime/cuda/nccl_api.cu
View file @
c560040f
/*!
/*!
* Copyright (c) 2021 by Contributors
* Copyright (c) 2021 by Contributors
* \file nccl_api.cc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file nccl_api.cu
* \brief Implementation of wrapper around NCCL routines.
* \brief Implementation of wrapper around NCCL routines.
*/
*/
#ifdef DGL_USE_NCCL
#ifdef DGL_USE_NCCL
#include "nccl_api.h"
#include "nccl_api.h"
...
@@ -627,12 +641,12 @@ void NCCLCommunicator::AllToAll(
...
@@ -627,12 +641,12 @@ void NCCLCommunicator::AllToAll(
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
const
ncclDataType_t
type
=
NCCLType
<
IdType
>
();
const
ncclDataType_t
type
=
NCCLType
<
IdType
>
();
ncclGroupStart
();
NCCL_CALL
(
ncclGroupStart
()
)
;
for
(
int
r
=
0
;
r
<
size_
;
++
r
)
{
for
(
int
r
=
0
;
r
<
size_
;
++
r
)
{
ncclSend
(
send
+
(
r
*
count
),
count
,
type
,
r
,
comm_
,
stream
);
NCCL_CALL
(
ncclSend
(
send
+
(
r
*
count
),
count
,
type
,
r
,
comm_
,
stream
)
)
;
ncclRecv
(
recv
+
(
r
*
count
),
count
,
type
,
r
,
comm_
,
stream
);
NCCL_CALL
(
ncclRecv
(
recv
+
(
r
*
count
),
count
,
type
,
r
,
comm_
,
stream
)
)
;
}
}
ncclGroupEnd
();
NCCL_CALL
(
ncclGroupEnd
()
)
;
}
}
template
template
...
@@ -662,24 +676,27 @@ void NCCLCommunicator::SparseAllToAll(
...
@@ -662,24 +676,27 @@ void NCCLCommunicator::SparseAllToAll(
const
ncclDataType_t
idx_type
=
NCCLType
<
IdType
>
();
const
ncclDataType_t
idx_type
=
NCCLType
<
IdType
>
();
const
ncclDataType_t
value_type
=
NCCLType
<
DType
>
();
const
ncclDataType_t
value_type
=
NCCLType
<
DType
>
();
ncclGroupStart
();
// idxs
AllToAllV
(
send_idx
,
send_prefix
,
recv_idx
,
recv_prefix
,
stream
);
// values
NCCL_CALL
(
ncclGroupStart
());
for
(
int
r
=
0
;
r
<
size_
;
++
r
)
{
for
(
int
r
=
0
;
r
<
size_
;
++
r
)
{
const
int64_t
send_size
=
send_prefix
[
r
+
1
]
-
send_prefix
[
r
];
const
int64_t
send_size
=
send_prefix
[
r
+
1
]
-
send_prefix
[
r
];
if
(
send_size
>
0
)
{
if
(
send_size
>
0
)
{
ncclSend
(
send_idx
+
send_prefix
[
r
],
send_size
,
idx_type
,
r
,
comm_
,
stream
);
NCCL_CALL
(
ncclSend
(
send_value
+
send_prefix
[
r
]
*
num_feat
,
send_size
*
num_feat
,
ncclSend
(
send_value
+
send_prefix
[
r
]
*
num_feat
,
send_size
*
num_feat
,
value_type
,
r
,
comm_
,
stream
));
value_type
,
r
,
comm_
,
stream
);
}
}
const
int64_t
recv_size
=
recv_prefix
[
r
+
1
]
-
recv_prefix
[
r
];
const
int64_t
recv_size
=
recv_prefix
[
r
+
1
]
-
recv_prefix
[
r
];
if
(
recv_size
>
0
)
{
if
(
recv_size
>
0
)
{
ncclRecv
(
recv_idx
+
recv_prefix
[
r
],
recv_size
,
idx_type
,
r
,
comm_
,
stream
);
NCCL_CALL
(
ncclRecv
(
recv_value
+
recv_prefix
[
r
]
*
num_feat
,
recv_size
*
num_feat
,
ncclRecv
(
recv_value
+
recv_prefix
[
r
]
*
num_feat
,
recv_size
*
num_feat
,
value_type
,
r
,
comm_
,
stream
));
value_type
,
r
,
comm_
,
stream
);
}
}
}
}
ncclGroupEnd
();
NCCL_CALL
(
ncclGroupEnd
()
)
;
}
}
template
template
void
NCCLCommunicator
::
SparseAllToAll
<
int32_t
,
__half
>(
void
NCCLCommunicator
::
SparseAllToAll
<
int32_t
,
__half
>(
const
int32_t
*
const
send_idx
,
const
int32_t
*
const
send_idx
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment