Commit 9b685226 authored by ThomasNing's avatar ThomasNing
Browse files

Finished the poc for MSCCLPP

parent cd71c0a0
...@@ -17,10 +17,12 @@ ...@@ -17,10 +17,12 @@
#pragma clang diagnostic ignored "-Winconsistent-missing-destructor-override" #pragma clang diagnostic ignored "-Winconsistent-missing-destructor-override"
#pragma clang diagnostic ignored "-Wcast-align" #pragma clang diagnostic ignored "-Wcast-align"
#pragma clang diagnostic ignored "-Wglobal-constructors" #pragma clang diagnostic ignored "-Wglobal-constructors"
#pragma clang diagnostic ignored "-Wdeprecated-copy-with-user-provided-dtor"
#include <mscclpp/core.hpp> #include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp> #include <mscclpp/gpu_utils.hpp>
#include <mscclpp/sm_channel.hpp> #include <mscclpp/sm_channel.hpp>
#include <mscclpp/semaphore.hpp>
#pragma clang diagnostic pop #pragma clang diagnostic pop
...@@ -30,27 +32,67 @@ ...@@ -30,27 +32,67 @@
template <class T> template <class T>
using DeviceHandle = mscclpp::DeviceHandle<T>; using DeviceHandle = mscclpp::DeviceHandle<T>;
__constant__ DeviceHandle<mscclpp::SmChannel> constSmChannels[8]; // For SmChannel extern __constant__ DeviceHandle<mscclpp::SmChannel> constSlaveSmChannels[8]; // For SmChannel
void setupConnection(
void setupConnection(int rank, int worldSize, void* data, size_t dataSize){ int rank, int slaveRank, int worldSize, void* src_data, void* dst_data, size_t dataSize)
{
// Initialize MSCCL++ Communicator // Initialize MSCCL++ Communicator
mscclpp::Transport transport = mscclpp::Transport::SmChannel; auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, worldSize);
// Create the communicator
auto bootstrap = std::make_shared<mscclpp::Bootstrap>(rank, worldSize);
mscclpp::Communicator comm(bootstrap); mscclpp::Communicator comm(bootstrap);
// Allocate and register memory mscclpp::Transport transport = mscclpp::Transport::CudaIpc;
auto localMemory = comm.registerMemory(data, dataSize, transport);
std::vector<mscclpp::RegisteredMemory> remoteMemories; if(rank == slaveRank)
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections; {
if (rank == 0) { std::vector<mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remoteMemories;
for(int senderRank = 1; senderRank < worldSize; ++senderRank) { std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connections(
connections[senderRank] = comm.connectOnSetup(senderRank, 0, mscclpp::Transport::SmChannel); worldSize);
// Receive memory from sender std::vector<std::shared_ptr<mscclpp::SmDevice2DeviceSemaphore>> slave_semaphore_list(
remoteMemories.push_back(comm.recvMemoryOnSetup(senderRank, 0)); worldSize);
for(size_t senderRank = 0; senderRank < static_cast<size_t>(worldSize); ++senderRank)
{
if(senderRank == static_cast<size_t>(rank))
continue;
connections[senderRank] = comm.connectOnSetup(senderRank, 0, transport);
remoteMemories.push_back(comm.recvMemoryOnSetup(senderRank, 0));
}
comm.setup();
for(size_t senderRank = 0; senderRank < static_cast<size_t>(worldSize); ++senderRank)
{
if(senderRank == static_cast<size_t>(rank))
continue;
auto connection = connections[senderRank].get();
slave_semaphore_list[senderRank] =
std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, connection);
} }
} else { std::vector<DeviceHandle<mscclpp::SmChannel>> SmChannels;
connections[0] = comm.connectOnSetup(0, 0, mscclpp::Transport::SmChannel); for(size_t i = 0; i < slave_semaphore_list.size(); ++i)
{
SmChannels.push_back(mscclpp::deviceHandle(
mscclpp::SmChannel(slave_semaphore_list[i], remoteMemories[i].get(), src_data)));
}
hipError_t error =
hipMemcpyToSymbol(constSlaveSmChannels,
SmChannels.data(),
sizeof(DeviceHandle<mscclpp::SmChannel>) * SmChannels.size());
if(error != hipSuccess)
{
std::cerr << "Error locating data to constant memory" << std::endl;
return;
}
}
else
{
auto localMemory = comm.registerMemory(dst_data, dataSize, transport);
mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>> connection =
comm.connectOnSetup(slaveRank, 0, transport);
comm.sendMemoryOnSetup(localMemory, slaveRank, 0);
comm.setup();
auto sender_semaphore =
std::make_shared<mscclpp::SmDevice2DeviceSemaphore>(comm, connection.get());
auto tempSmChannel = mscclpp::SmChannel(sender_semaphore, localMemory, src_data);
DeviceHandle<mscclpp::SmChannel> SenderSmChannel = mscclpp::deviceHandle(tempSmChannel);
} }
} }
...@@ -158,8 +200,8 @@ struct AllocateAndTransferFunctor ...@@ -158,8 +200,8 @@ struct AllocateAndTransferFunctor
else else
{ {
const void* send_location_ptr = host_receive_ptr_future.get(); const void* send_location_ptr = host_receive_ptr_future.get();
args_send.p_send = send_location_ptr; args_send.p_send = send_location_ptr;
auto kargs_master = MasterKernel::MakeKargs( auto kargs_master = MasterKernel::MakeKargs(
args_send.p_reduce, args_send.p_send, args_send.M, args_send.N); args_send.p_reduce, args_send.p_send, args_send.M, args_send.N);
const dim3 grids_master = MasterKernel::GridSize(M, N); const dim3 grids_master = MasterKernel::GridSize(M, N);
ave_time = ck_tile::launch_kernel( ave_time = ck_tile::launch_kernel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment