Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b06330d0
Commit
b06330d0
authored
Dec 14, 2024
by
ThomasNing
Browse files
Polish the setup Connection part from Nusrat's comment
parent
9b685226
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
23 deletions
+70
-23
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.cpp
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.cpp
+70
-23
No files found.
example/ck_tile/15_cross_gpu_reduce/cross_gpu_reduce.cpp
View file @
b06330d0
...
@@ -34,8 +34,9 @@ template <class T>
...
@@ -34,8 +34,9 @@ template <class T>
using
DeviceHandle
=
mscclpp
::
DeviceHandle
<
T
>
;
using
DeviceHandle
=
mscclpp
::
DeviceHandle
<
T
>
;
extern
__constant__
DeviceHandle
<
mscclpp
::
SmChannel
>
constSlaveSmChannels
[
8
];
// For SmChannel
extern
__constant__
DeviceHandle
<
mscclpp
::
SmChannel
>
constSlaveSmChannels
[
8
];
// For SmChannel
void
setupConnection
(
extern
__constant__
DeviceHandle
<
mscclpp
::
SmChannel
>
constMasterSmChannel
;
int
rank
,
int
slaveRank
,
int
worldSize
,
void
*
src_data
,
void
*
dst_data
,
size_t
dataSize
)
void
setupConnection
(
int
rank
,
int
slaveRank
,
int
worldSize
,
void
*
dst_data
,
size_t
dataSize
)
{
{
// Initialize MSCCL++ Communicator
// Initialize MSCCL++ Communicator
auto
bootstrap
=
std
::
make_shared
<
mscclpp
::
TcpBootstrap
>
(
rank
,
worldSize
);
auto
bootstrap
=
std
::
make_shared
<
mscclpp
::
TcpBootstrap
>
(
rank
,
worldSize
);
...
@@ -43,40 +44,67 @@ void setupConnection(
...
@@ -43,40 +44,67 @@ void setupConnection(
mscclpp
::
Communicator
comm
(
bootstrap
);
mscclpp
::
Communicator
comm
(
bootstrap
);
mscclpp
::
Transport
transport
=
mscclpp
::
Transport
::
CudaIpc
;
mscclpp
::
Transport
transport
=
mscclpp
::
Transport
::
CudaIpc
;
// We'll register our local memory. For the slave, this might be the destination buffer.
// For senders, this might be the source buffer or a local buffer we expose to the slave.
mscclpp
::
RegisteredMemory
localMemory
=
comm
.
registerMemory
(
dst_data
,
dataSize
,
transport
);
if
(
rank
==
slaveRank
)
if
(
rank
==
slaveRank
)
{
{
std
::
vector
<
mscclpp
::
NonblockingFuture
<
mscclpp
::
RegisteredMemory
>>
remoteMemories
;
std
::
vector
<
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>>
std
::
vector
<
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>>
connections
(
connectionFutures
;
worldSize
)
;
std
::
vector
<
mscclpp
::
NonblockingFuture
<
mscclpp
::
RegisteredMemory
>>
remoteMemFutures
;
std
::
vector
<
std
::
shared_ptr
<
mscclpp
::
SmDevice2DeviceSemaphore
>>
slave_semaphore_list
(
std
::
vector
<
std
::
shared_ptr
<
mscclpp
::
SmDevice2DeviceSemaphore
>>
slave_semaphore_list
(
worldSize
);
worldSize
);
for
(
size_t
senderRank
=
0
;
senderRank
<
static_cast
<
size_t
>
(
worldSize
);
++
senderRank
)
for
(
size_t
senderRank
=
0
;
senderRank
<
static_cast
<
size_t
>
(
worldSize
);
++
senderRank
)
{
{
if
(
senderRank
==
static_cast
<
size_t
>
(
rank
))
if
(
senderRank
==
static_cast
<
size_t
>
(
rank
))
continue
;
continue
;
connections
[
senderRank
]
=
comm
.
connectOnSetup
(
senderRank
,
0
,
transport
);
connectionFutures
.
push_back
(
comm
.
connectOnSetup
(
senderRank
,
0
,
transport
));
remoteMemories
.
push_back
(
comm
.
recvMemoryOnSetup
(
senderRank
,
0
));
comm
.
sendMemoryOnSetup
(
localMemory
,
senderRank
,
0
);
remoteMemFutures
.
push_back
(
comm
.
recvMemoryOnSetup
(
senderRank
,
0
));
}
}
comm
.
setup
();
comm
.
setup
();
for
(
size_t
senderRank
=
0
;
senderRank
<
static_cast
<
size_t
>
(
worldSize
);
++
senderRank
)
// Now retrieve all completed futures
std
::
vector
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>
connections
;
connections
.
reserve
(
connectionFutures
.
size
());
for
(
auto
&
cf
:
connectionFutures
)
{
{
if
(
senderRank
==
static_cast
<
size_t
>
(
rank
))
connections
.
push_back
(
cf
.
get
());
continue
;
auto
connection
=
connections
[
senderRank
].
get
();
slave_semaphore_list
[
senderRank
]
=
std
::
make_shared
<
mscclpp
::
SmDevice2DeviceSemaphore
>
(
comm
,
connection
);
}
}
std
::
vector
<
mscclpp
::
RegisteredMemory
>
remoteMemories
;
remoteMemories
.
reserve
(
remoteMemFutures
.
size
());
for
(
auto
&
rmf
:
remoteMemFutures
)
{
remoteMemories
.
push_back
(
rmf
.
get
());
}
// Create semaphores and channels
// One semaphore per connection
std
::
vector
<
std
::
shared_ptr
<
mscclpp
::
SmDevice2DeviceSemaphore
>>
slaveSemaphores
;
slaveSemaphores
.
reserve
(
connections
.
size
());
for
(
auto
&
conn
:
connections
)
{
slaveSemaphores
.
push_back
(
std
::
make_shared
<
mscclpp
::
SmDevice2DeviceSemaphore
>
(
comm
,
conn
));
}
// Create channels
std
::
vector
<
DeviceHandle
<
mscclpp
::
SmChannel
>>
SmChannels
;
std
::
vector
<
DeviceHandle
<
mscclpp
::
SmChannel
>>
SmChannels
;
for
(
size_t
i
=
0
;
i
<
slave_semaphore_list
.
size
();
++
i
)
SmChannels
.
reserve
(
slaveSemaphores
.
size
());
for
(
size_t
i
=
0
;
i
<
slaveSemaphores
.
size
();
++
i
)
{
{
SmChannels
.
push_back
(
mscclpp
::
deviceHandle
(
SmChannels
.
push_back
(
mscclpp
::
deviceHandle
(
mscclpp
::
SmChannel
(
slave_semaphore_list
[
i
],
remoteMemories
[
i
].
get
(),
src_data
)));
mscclpp
::
SmChannel
(
slaveSemaphores
[
i
],
remoteMemories
[
i
],
// Remote buffer from the sender
dst_data
// Local buffer (this slave's buffer)
)));
}
}
hipError_t
error
=
hipError_t
error
_slave
=
hipMemcpyToSymbol
(
constSlaveSmChannels
,
hipMemcpyToSymbol
(
constSlaveSmChannels
,
SmChannels
.
data
(),
SmChannels
.
data
(),
sizeof
(
DeviceHandle
<
mscclpp
::
SmChannel
>
)
*
SmChannels
.
size
());
sizeof
(
DeviceHandle
<
mscclpp
::
SmChannel
>
)
*
SmChannels
.
size
());
if
(
error
!=
hipSuccess
)
if
(
error
_slave
!=
hipSuccess
)
{
{
std
::
cerr
<<
"Error locating data to constant memory"
<<
std
::
endl
;
std
::
cerr
<<
"Error locating data to constant memory"
<<
std
::
endl
;
return
;
return
;
...
@@ -84,15 +112,34 @@ void setupConnection(
...
@@ -84,15 +112,34 @@ void setupConnection(
}
}
else
else
{
{
auto
localMemory
=
comm
.
registerMemory
(
dst_data
,
dataSize
,
transport
);
// This is a sender:
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>
connection
=
// We only connect to the slave, send our memory handle, and receive the slave's memory
// handle.
mscclpp
::
NonblockingFuture
<
std
::
shared_ptr
<
mscclpp
::
Connection
>>
connectionFuture
=
comm
.
connectOnSetup
(
slaveRank
,
0
,
transport
);
comm
.
connectOnSetup
(
slaveRank
,
0
,
transport
);
// Send our memory to the slave
comm
.
sendMemoryOnSetup
(
localMemory
,
slaveRank
,
0
);
comm
.
sendMemoryOnSetup
(
localMemory
,
slaveRank
,
0
);
// Receive slave's memory
mscclpp
::
NonblockingFuture
<
mscclpp
::
RegisteredMemory
>
remoteMemoryFuture
=
comm
.
recvMemoryOnSetup
(
slaveRank
,
0
);
comm
.
setup
();
comm
.
setup
();
auto
sender_semaphore
=
std
::
shared_ptr
<
mscclpp
::
Connection
>
connection
=
connectionFuture
.
get
();
std
::
make_shared
<
mscclpp
::
SmDevice2DeviceSemaphore
>
(
comm
,
connection
.
get
());
mscclpp
::
RegisteredMemory
remoteMemory
=
remoteMemoryFuture
.
get
();
auto
tempSmChannel
=
mscclpp
::
SmChannel
(
sender_semaphore
,
localMemory
,
src_data
);
DeviceHandle
<
mscclpp
::
SmChannel
>
SenderSmChannel
=
mscclpp
::
deviceHandle
(
tempSmChannel
);
auto
senderSemaphore
=
std
::
make_shared
<
mscclpp
::
SmDevice2DeviceSemaphore
>
(
comm
,
connection
);
auto
senderChannel
=
mscclpp
::
SmChannel
(
senderSemaphore
,
localMemory
,
remoteMemory
.
data
());
DeviceHandle
<
mscclpp
::
SmChannel
>
senderSmChannel
=
mscclpp
::
deviceHandle
(
senderChannel
);
hipError_t
error_master
=
hipMemcpyToSymbol
(
constMasterSmChannel
,
&
senderSmChannel
,
sizeof
(
DeviceHandle
<
mscclpp
::
SmChannel
>
));
if
(
error_master
!=
hipSuccess
)
{
std
::
cerr
<<
"Error locating data to constant memory"
<<
std
::
endl
;
return
;
}
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment