Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9b685226
Commit
9b685226
authored
Dec 13, 2024
by
ThomasNing
Browse files
Finished the poc for MSCCLPP
parent
cd71c0a0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
19 deletions
+61
-19
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.cpp
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.cpp
+61
-19
No files found.
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.cpp
View file @
9b685226
...
@@ -17,10 +17,12 @@
...
@@ -17,10 +17,12 @@
#pragma clang diagnostic ignored "-Winconsistent-missing-destructor-override"
#pragma clang diagnostic ignored "-Winconsistent-missing-destructor-override"
#pragma clang diagnostic ignored "-Wcast-align"
#pragma clang diagnostic ignored "-Wcast-align"
#pragma clang diagnostic ignored "-Wglobal-constructors"
#pragma clang diagnostic ignored "-Wglobal-constructors"
#pragma clang diagnostic ignored "-Wdeprecated-copy-with-user-provided-dtor"
#include <mscclpp/core.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/semaphore.hpp>
#pragma clang diagnostic pop
#pragma clang diagnostic pop
...
@@ -30,27 +32,67 @@
...
@@ -30,27 +32,67 @@
template
<
class
T
>
template
<
class
T
>
using
DeviceHandle
=
mscclpp
::
DeviceHandle
<
T
>
;
using
DeviceHandle
=
mscclpp
::
DeviceHandle
<
T
>
;
__constant__
DeviceHandle
<
mscclpp
::
SmChannel
>
constSmChannels
[
8
];
// For SmChannel
extern
__constant__
DeviceHandle
<
mscclpp
::
SmChannel
>
constS
laveS
mChannels
[
8
];
// For SmChannel
void
setupConnection
(
void
setupConnection
(
int
rank
,
int
worldSize
,
void
*
data
,
size_t
dataSize
){
int
rank
,
int
slaveRank
,
int
worldSize
,
void
*
src_data
,
void
*
dst_data
,
size_t
dataSize
)
{
// Initialize MSCCL++ Communicator
// Initialize MSCCL++ Communicator
mscclpp
::
Transport
transport
=
mscclpp
::
Transport
::
SmChannel
;
auto
bootstrap
=
std
::
make_shared
<
mscclpp
::
TcpBootstrap
>
(
rank
,
worldSize
);
// Create the communicator
auto
bootstrap
=
std
::
make_shared
<
mscclpp
::
Bootstrap
>
(
rank
,
worldSize
);
mscclpp
::
Communicator
comm
(
bootstrap
);
mscclpp
::
Communicator
comm
(
bootstrap
);
// Allocate and register memory
mscclpp
::
Transport
transport
=
mscclpp
::
Transport
::
CudaIpc
;
auto
localMemory
=
comm
.
registerMemory
(
data
,
dataSize
,
transport
);
std
::
vector
<
mscclpp
::
RegisteredMemory
>
remoteMemories
;
if
(
rank
==
slaveRank
)
std
::
vector
<
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>>
connections
;
{
if
(
rank
==
0
)
{
std
::
vector
<
mscclpp
::
NonblockingFuture
<
mscclpp
::
RegisteredMemory
>>
remoteMemories
;
for
(
int
senderRank
=
1
;
senderRank
<
worldSize
;
++
senderRank
)
{
std
::
vector
<
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>>
connections
(
connections
[
senderRank
]
=
comm
.
connectOnSetup
(
senderRank
,
0
,
mscclpp
::
Transport
::
SmChannel
);
worldSize
);
// Receive memory from sender
std
::
vector
<
std
::
shared_ptr
<
mscclpp
::
SmDevice2DeviceSemaphore
>>
slave_semaphore_list
(
remoteMemories
.
push_back
(
comm
.
recvMemoryOnSetup
(
senderRank
,
0
));
worldSize
);
for
(
size_t
senderRank
=
0
;
senderRank
<
static_cast
<
size_t
>
(
worldSize
);
++
senderRank
)
{
if
(
senderRank
==
static_cast
<
size_t
>
(
rank
))
continue
;
connections
[
senderRank
]
=
comm
.
connectOnSetup
(
senderRank
,
0
,
transport
);
remoteMemories
.
push_back
(
comm
.
recvMemoryOnSetup
(
senderRank
,
0
));
}
comm
.
setup
();
for
(
size_t
senderRank
=
0
;
senderRank
<
static_cast
<
size_t
>
(
worldSize
);
++
senderRank
)
{
if
(
senderRank
==
static_cast
<
size_t
>
(
rank
))
continue
;
auto
connection
=
connections
[
senderRank
].
get
();
slave_semaphore_list
[
senderRank
]
=
std
::
make_shared
<
mscclpp
::
SmDevice2DeviceSemaphore
>
(
comm
,
connection
);
}
}
}
else
{
std
::
vector
<
DeviceHandle
<
mscclpp
::
SmChannel
>>
SmChannels
;
connections
[
0
]
=
comm
.
connectOnSetup
(
0
,
0
,
mscclpp
::
Transport
::
SmChannel
);
for
(
size_t
i
=
0
;
i
<
slave_semaphore_list
.
size
();
++
i
)
{
SmChannels
.
push_back
(
mscclpp
::
deviceHandle
(
mscclpp
::
SmChannel
(
slave_semaphore_list
[
i
],
remoteMemories
[
i
].
get
(),
src_data
)));
}
hipError_t
error
=
hipMemcpyToSymbol
(
constSlaveSmChannels
,
SmChannels
.
data
(),
sizeof
(
DeviceHandle
<
mscclpp
::
SmChannel
>
)
*
SmChannels
.
size
());
if
(
error
!=
hipSuccess
)
{
std
::
cerr
<<
"Error locating data to constant memory"
<<
std
::
endl
;
return
;
}
}
else
{
auto
localMemory
=
comm
.
registerMemory
(
dst_data
,
dataSize
,
transport
);
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>
connection
=
comm
.
connectOnSetup
(
slaveRank
,
0
,
transport
);
comm
.
sendMemoryOnSetup
(
localMemory
,
slaveRank
,
0
);
comm
.
setup
();
auto
sender_semaphore
=
std
::
make_shared
<
mscclpp
::
SmDevice2DeviceSemaphore
>
(
comm
,
connection
.
get
());
auto
tempSmChannel
=
mscclpp
::
SmChannel
(
sender_semaphore
,
localMemory
,
src_data
);
DeviceHandle
<
mscclpp
::
SmChannel
>
SenderSmChannel
=
mscclpp
::
deviceHandle
(
tempSmChannel
);
}
}
}
}
...
@@ -158,8 +200,8 @@ struct AllocateAndTransferFunctor
...
@@ -158,8 +200,8 @@ struct AllocateAndTransferFunctor
else
else
{
{
const
void
*
send_location_ptr
=
host_receive_ptr_future
.
get
();
const
void
*
send_location_ptr
=
host_receive_ptr_future
.
get
();
args_send
.
p_send
=
send_location_ptr
;
args_send
.
p_send
=
send_location_ptr
;
auto
kargs_master
=
MasterKernel
::
MakeKargs
(
auto
kargs_master
=
MasterKernel
::
MakeKargs
(
args_send
.
p_reduce
,
args_send
.
p_send
,
args_send
.
M
,
args_send
.
N
);
args_send
.
p_reduce
,
args_send
.
p_send
,
args_send
.
M
,
args_send
.
N
);
const
dim3
grids_master
=
MasterKernel
::
GridSize
(
M
,
N
);
const
dim3
grids_master
=
MasterKernel
::
GridSize
(
M
,
N
);
ave_time
=
ck_tile
::
launch_kernel
(
ave_time
=
ck_tile
::
launch_kernel
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment