#pragma once #include #include #include #include #include #include #include #include #include "base.h" #include "net_utils.h" #include "socket.h" namespace sccl { namespace hardware { namespace net { namespace net_socket { /* Communication functions */ static constexpr int MAX_SOCKETS = 64; static constexpr int MAX_THREADS = 16; static constexpr int MAX_REQUESTS = SCCL_NET_MAX_REQUESTS; static constexpr int MIN_CHUNKSIZE = (64 * 1024); enum scclNetSocketCommState : uint8_t { scclNetSocketCommStateStart = 0, scclNetSocketCommStateConnect = 1, scclNetSocketCommStateAccept = 3, scclNetSocketCommStateSend = 4, scclNetSocketCommStateRecv = 5, }; struct scclNetSocketCommStage { enum scclNetSocketCommState state; uint8_t iteration; struct scclSocket* sock; struct scclNetSocketComm* comm = nullptr; }; struct scclNetSocketHandle { union scclSocketAddress connectAddr; uint64_t magic; // random number to help debugging int nSocks; int nThreads; struct scclNetSocketCommStage stage; }; struct scclNetSocketTask { int op; void* data; int size; struct scclSocket* sock = nullptr; int offset; int used; scclResult_t result; }; struct scclNetSocketRequest { int op; void* data; int size; struct scclSocket* ctrlSock = nullptr; int offset; int used; struct scclNetSocketComm* comm = nullptr; struct scclNetSocketTask* tasks[MAX_SOCKETS] = {nullptr}; int nSubs; }; struct scclNetSocketTaskQueue { int next; int len; struct scclNetSocketTask* tasks = nullptr; }; struct scclNetSocketThreadResources { struct scclNetSocketTaskQueue threadTaskQueue; int stop; struct scclNetSocketComm* comm = nullptr; pthread_mutex_t threadLock; pthread_cond_t threadCond; }; struct scclNetSocketListenComm { struct scclSocket sock; struct scclNetSocketCommStage stage; int nSocks; int nThreads; int dev; }; struct scclNetSocketComm { struct scclSocket ctrlSock; struct scclSocket socks[MAX_SOCKETS]; int dev; int hipDev; int nSocks; int nThreads; int nextSock; struct scclNetSocketRequest requests[MAX_REQUESTS]; pthread_t helperThread[MAX_THREADS]; struct scclNetSocketThreadResources threadResources[MAX_THREADS]; }; ////////////////////////////////// class scclNetSocket : public scclNetBase { public: // 构造函数和析构函数 scclNetSocket(); virtual ~scclNetSocket(); // 初始化网络。 scclResult_t init() override; // 返回适配器的数量。 scclResult_t devices(int* ndev) override; // 获取各种设备属性。 scclResult_t getProperties(int dev, scclNetProperties_t* props) override; // 创建一个接收对象并提供一个句柄以连接到它。该句柄最多可以是 SCCL_NET_HANDLE_MAXSIZE 字节,并将在排名之间交换以创建连接。 scclResult_t listen(int dev, void* handle, void** listenComm) override; // 连接到一个句柄并返回一个发送 comm 对象给该对等体。 // 此调用不应阻塞以建立连接,而应成功返回 sendComm == NULL,并期望再次调用直到 sendComm != NULL。 scclResult_t connect(int dev, void* handle, void** sendComm) override; // 在远程对等体调用 connect 后最终确定连接建立。 // 此调用不应阻塞以建立连接,而应成功返回 recvComm == NULL,并期望再次调用直到 recvComm != NULL。 scclResult_t accept(void* listenComm, void** recvComm) override; // 注册/注销内存。Comm 可以是 sendComm 或 recvComm。 // 类型是 SCCL_PTR_HOST 或 SCCL_PTR_CUDA。 scclResult_t regMr(void* comm, void* data, int size, int type, void** mhandle) override; /* DMA-BUF 支持 */ scclResult_t regMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) override; // 注销IB内存区域(MR) scclResult_t deregMr(void* comm, void* mhandle) override; // 异步发送到对等体。 // 如果调用不能执行(或会阻塞),则可能返回 request == NULL scclResult_t isend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) override; // 异步从对等体接收。 如果调用不能执行(或会阻塞),则可能返回 request == NULL scclResult_t irecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) override; // 执行刷新/栅栏操作,以确保所有使用 SCCL_PTR_CUDA 接收到的数据对 GPU 可见 scclResult_t iflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) override; // 测试请求是否完成。如果 size 不为 NULL,则返回发送/接收的字节数。 scclResult_t test(void* request, int* done, int* sizes) override; // 关闭并释放 send/recv comm 对象 scclResult_t closeSend(void* sendComm) override; scclResult_t closeRecv(void* recvComm) override; scclResult_t closeListen(void* listenComm) override; private: struct scclNetSocketListenComm* socketComm = nullptr; private: // 获取网络设备的PCI路径 static scclResult_t scclNetSocketGetPciPath(char* devName, char** pciPath); // 获取指定网络设备的速度(单位:Mbps) scclResult_t scclNetSocketGetSpeed(char* devName, int* speed); // 持久化socket线程处理函数 static void* persistentSocketThread(void* args_); // 为指定通信对象创建并获取一个网络套接字任务 scclResult_t scclNetSocketGetTask(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketTask** req); // 获取指定设备的socket和线程数量配置 scclResult_t scclNetSocketGetNsockNthread(int dev, int* ns, int* nt); }; } // namespace net_socket } // namespace net } // namespace hardware } // namespace sccl