Commit b06330d0 authored by ThomasNing's avatar ThomasNing
Browse files

Polish the setup Connection part from Nusrat's comment

parent 9b685226
...@@ -34,8 +34,9 @@ template <class T> ...@@ -34,8 +34,9 @@ template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>; using DeviceHandle = mscclpp::DeviceHandle<T>;
extern __constant__ DeviceHandle<mscclpp::SmChannel> constSlaveSmChannels[8]; // For SmChannel extern __constant__ DeviceHandle<mscclpp::SmChannel> constSlaveSmChannels[8]; // For SmChannel
void setupConnection( extern __constant__ DeviceHandle<mscclpp::SmChannel> constMasterSmChannel;
int rank, int slaveRank, int worldSize, void* src_data, void* dst_data, size_t dataSize)
void setupConnection(int rank, int slaveRank, int worldSize, void* dst_data, size_t dataSize)
{ {
// Initialize MSCCL++ Communicator // Initialize MSCCL++ Communicator
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, worldSize); auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, worldSize);
...@@ -43,40 +44,67 @@ void setupConnection( ...@@ -43,40 +44,67 @@ void setupConnection(
mscclpp::Communicator comm(bootstrap); mscclpp::Communicator comm(bootstrap);
mscclpp::Transport transport = mscclpp::Transport::CudaIpc; mscclpp::Transport transport = mscclpp::Transport::CudaIpc;
// We'll register our local memory. For the slave, this might be the destination buffer.
// For senders, this might be the source buffer or a local buffer we expose to the slave.
mscclpp::RegisteredMemory localMemory = comm.registerMemory(dst_data, dataSize, transport);
if(rank == slaveRank) if(rank == slaveRank)
{ {
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories; std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>>
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections( connectionFutures;
worldSize); std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemFutures;
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> slave_semaphore_list( std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> slave_semaphore_list(
worldSize); worldSize);
for(size_t senderRank = 0; senderRank < static_cast<size_t>(worldSize); ++senderRank) for(size_t senderRank = 0; senderRank < static_cast<size_t>(worldSize); ++senderRank)
{ {
if(senderRank == static_cast<size_t>(rank)) if(senderRank == static_cast<size_t>(rank))
continue; continue;
connections[senderRank] = comm.connectOnSetup(senderRank, 0, transport); connectionFutures.push_back(comm.connectOnSetup(senderRank, 0, transport));
remoteMemories.push_back(comm.recvMemoryOnSetup(senderRank, 0)); comm.sendMemoryOnSetup(localMemory, senderRank, 0);
remoteMemFutures.push_back(comm.recvMemoryOnSetup(senderRank, 0));
} }
comm.setup(); comm.setup();
for(size_t senderRank = 0; senderRank < static_cast<size_t>(worldSize); ++senderRank) // Now retrieve all completed futures
std::vector<std::shared_ptr<mscclpp::Connection>> connections;
connections.reserve(connectionFutures.size());
for(auto& cf : connectionFutures)
{ {
if(senderRank == static_cast<size_t>(rank)) connections.push_back(cf.get());
continue; }
auto connection = connections[senderRank].get();
slave_semaphore_list[senderRank] = std::vector<mscclpp::RegisteredMemory> remoteMemories;
std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, connection); remoteMemories.reserve(remoteMemFutures.size());
for(auto& rmf : remoteMemFutures)
{
remoteMemories.push_back(rmf.get());
} }
// Create semaphores and channels
// One semaphore per connection
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> slaveSemaphores;
slaveSemaphores.reserve(connections.size());
for(auto& conn : connections)
{
slaveSemaphores.push_back(
std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, conn));
}
// Create channels
std::vector<DeviceHandle<mscclpp::SmChannel>> SmChannels; std::vector<DeviceHandle<mscclpp::SmChannel>> SmChannels;
for(size_t i = 0; i < slave_semaphore_list.size(); ++i) SmChannels.reserve(slaveSemaphores.size());
for(size_t i = 0; i < slaveSemaphores.size(); ++i)
{ {
SmChannels.push_back(mscclpp::deviceHandle( SmChannels.push_back(mscclpp::deviceHandle(
mscclpp::SmChannel(slave_semaphore_list[i], remoteMemories[i].get(), src_data))); mscclpp::SmChannel(slaveSemaphores[i],
remoteMemories[i], // Remote buffer from the sender
dst_data // Local buffer (this slave's buffer)
)));
} }
hipError_t error = hipError_t error_slave =
hipMemcpyToSymbol(constSlaveSmChannels, hipMemcpyToSymbol(constSlaveSmChannels,
SmChannels.data(), SmChannels.data(),
sizeof(DeviceHandle<mscclpp::SmChannel>) * SmChannels.size()); sizeof(DeviceHandle<mscclpp::SmChannel>) * SmChannels.size());
if(error != hipSuccess) if(error_slave != hipSuccess)
{ {
std::cerr << "Error locating data to constant memory" << std::endl; std::cerr << "Error locating data to constant memory" << std::endl;
return; return;
...@@ -84,15 +112,34 @@ void setupConnection( ...@@ -84,15 +112,34 @@ void setupConnection(
} }
else else
{ {
auto localMemory = comm.registerMemory(dst_data, dataSize, transport); // This is a sender:
mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>> connection = // We only connect to the slave, send our memory handle, and receive the slave's memory
// handle.
mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>> connectionFuture =
comm.connectOnSetup(slaveRank, 0, transport); comm.connectOnSetup(slaveRank, 0, transport);
// Send our memory to the slave
comm.sendMemoryOnSetup(localMemory, slaveRank, 0); comm.sendMemoryOnSetup(localMemory, slaveRank, 0);
// Receive slave's memory
mscclpp::NonblockingFuture<mscclpp::RegisteredMemory> remoteMemoryFuture =
comm.recvMemoryOnSetup(slaveRank, 0);
comm.setup(); comm.setup();
auto sender_semaphore = std::shared_ptr<mscclpp::Connection> connection = connectionFuture.get();
std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, connection.get()); mscclpp::RegisteredMemory remoteMemory = remoteMemoryFuture.get();
auto tempSmChannel = mscclpp::SmChannel(sender_semaphore, localMemory, src_data);
DeviceHandle<mscclpp::SmChannel> SenderSmChannel = mscclpp::deviceHandle(tempSmChannel); auto senderSemaphore =
std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, connection);
auto senderChannel = mscclpp::SmChannel(senderSemaphore, localMemory, remoteMemory.data());
DeviceHandle<mscclpp::SmChannel> senderSmChannel = mscclpp::deviceHandle(senderChannel);
hipError_t error_master = hipMemcpyToSymbol(
constMasterSmChannel, &senderSmChannel, sizeof(DeviceHandle<mscclpp::SmChannel>));
if(error_master != hipSuccess)
{
std::cerr << "Error locating data to constant memory" << std::endl;
return;
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment