Commit 9b685226 authored by ThomasNing's avatar ThomasNing
Browse files

Finished the poc for MSCCLPP

parent cd71c0a0
......@@ -17,10 +17,12 @@
#pragma clang diagnostic ignored "-Winconsistent-missing-destructor-override"
#pragma clang diagnostic ignored "-Wcast-align"
#pragma clang diagnostic ignored "-Wglobal-constructors"
#pragma clang diagnostic ignored "-Wdeprecated-copy-with-user-provided-dtor"
#include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/semaphore.hpp>
#pragma clang diagnostic pop
......@@ -30,27 +32,67 @@
// Shorthand for MSCCL++ device-side handles that GPU kernels can hold by value.
template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>;
// __constant__-memory arrays of SmChannel device handles, sized for at most 8 peers.
__constant__ DeviceHandle<mscclpp::SmChannel> constSmChannels[8]; // defined here; not referenced in this chunk
// Declared here, defined in another translation unit; populated by
// setupConnection via hipMemcpyToSymbol on the slave rank.
extern __constant__ DeviceHandle<mscclpp::SmChannel> constSlaveSmChannels[8]; // For SmChannel
void setupConnection(int rank, int worldSize, void* data, size_t dataSize){
void setupConnection(
int rank, int slaveRank, int worldSize, void* src_data, void* dst_data, size_t dataSize)
{
// Initialize MSCCL++ Communicator
mscclpp::Transport transport = mscclpp::Transport::SmChannel;
// Create the communicator
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, worldSize);
mscclpp::Communicator comm(bootstrap);
// Allocate and register memory
auto localMemory = comm.registerMemory(data, dataSize, transport);
std::vector<mscclpp::RegisteredMemory> remoteMemories;
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections;
if (rank == 0) {
for(int senderRank = 1; senderRank < worldSize; ++senderRank) {
connections[senderRank] = comm.connectOnSetup(senderRank, 0, mscclpp::Transport::SmChannel);
// Receive memory from sender
remoteMemories.push_back(comm.recvMemoryOnSetup(senderRank, 0));
mscclpp::Transport transport = mscclpp::Transport::CudaIpc;
if(rank == slaveRank)
{
std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories;
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections(
worldSize);
std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> slave_semaphore_list(
worldSize);
for(size_t senderRank = 0; senderRank < static_cast<size_t>(worldSize); ++senderRank)
{
if(senderRank == static_cast<size_t>(rank))
continue;
connections[senderRank] = comm.connectOnSetup(senderRank, 0, transport);
remoteMemories.push_back(comm.recvMemoryOnSetup(senderRank, 0));
}
comm.setup();
for(size_t senderRank = 0; senderRank < static_cast<size_t>(worldSize); ++senderRank)
{
if(senderRank == static_cast<size_t>(rank))
continue;
auto connection = connections[senderRank].get();
slave_semaphore_list[senderRank] =
std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, connection);
}
} else {
connections[0] = comm.connectOnSetup(0, 0, mscclpp::Transport::SmChannel);
std::vector<DeviceHandle<mscclpp::SmChannel>> SmChannels;
for(size_t i = 0; i < slave_semaphore_list.size(); ++i)
{
SmChannels.push_back(mscclpp::deviceHandle(
mscclpp::SmChannel(slave_semaphore_list[i], remoteMemories[i].get(), src_data)));
}
hipError_t error =
hipMemcpyToSymbol(constSlaveSmChannels,
SmChannels.data(),
sizeof(DeviceHandle<mscclpp::SmChannel>) * SmChannels.size());
if(error != hipSuccess)
{
std::cerr << "Error locating data to constant memory" << std::endl;
return;
}
}
else
{
auto localMemory = comm.registerMemory(dst_data, dataSize, transport);
mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>> connection =
comm.connectOnSetup(slaveRank, 0, transport);
comm.sendMemoryOnSetup(localMemory, slaveRank, 0);
comm.setup();
auto sender_semaphore =
std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, connection.get());
auto tempSmChannel = mscclpp::SmChannel(sender_semaphore, localMemory, src_data);
DeviceHandle<mscclpp::SmChannel> SenderSmChannel = mscclpp::deviceHandle(tempSmChannel);
}
}
......@@ -158,8 +200,8 @@ struct AllocateAndTransferFunctor
else
{
const void* send_location_ptr = host_receive_ptr_future.get();
args_send.p_send = send_location_ptr;
auto kargs_master = MasterKernel::MakeKargs(
args_send.p_send = send_location_ptr;
auto kargs_master = MasterKernel::MakeKargs(
args_send.p_reduce, args_send.p_send, args_send.M, args_send.N);
const dim3 grids_master = MasterKernel::GridSize(M, N);
ave_time = ck_tile::launch_kernel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment