#include #include #include #include #include #include "socket.h" #include "net_socket.h" namespace sccl { namespace hardware { namespace net { namespace net_socket { #define MAX_LINE_LEN (2047) /* Init functions */ static int scclNetIfs = -1; struct scclNetSocketDev { union scclSocketAddress addr; char devName[MAX_IF_NAME_SIZE]; char* pciPath; }; static struct scclNetSocketDev scclNetSocketDevs[MAX_IFS]; pthread_mutex_t scclNetSocketLock = PTHREAD_MUTEX_INITIALIZER; SCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2); SCCL_PARAM(SocketNthreads, "SOCKET_NTHREADS", -2); //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////// scclNetSocket调用的函数 //////////////////////////////////////// scclNetSocket::scclNetSocket() : scclNetBase("Socket") {} scclNetSocket::~scclNetSocket() { if(socketComm != nullptr) { free(socketComm); } } /** * 获取网络设备的PCI路径 * * @param devName 网络设备名称 * @param pciPath 输出参数,用于存储PCI路径的指针 * @return 返回操作结果(scclSuccess表示成功) * * @note 如果设备不存在,pciPath可能返回NULL */ scclResult_t scclNetSocket::scclNetSocketGetPciPath(char* devName, char** pciPath) { char devicePath[PATH_MAX]; snprintf(devicePath, PATH_MAX, "/sys/class/net/%s/device", devName); // May return NULL if the file doesn't exist. *pciPath = realpath(devicePath, NULL); return scclSuccess; } scclResult_t scclNetSocket::init() { SCCLCHECK(scclMalloc(&socketComm, 1)); if(scclNetIfs == -1) { pthread_mutex_lock(&scclNetSocketLock); if(scclNetIfs == -1) { char names[MAX_IF_NAME_SIZE * MAX_IFS]; union scclSocketAddress addrs[MAX_IFS]; scclNetIfs = scclFindSocketInterfaces(names, addrs, MAX_IF_NAME_SIZE, MAX_IFS); if(scclNetIfs <= 0) { WARN("NET/Socket : no interface found"); return scclInternalError; } else { char line[MAX_LINE_LEN + 1]; char addrline[SOCKET_NAME_MAXLEN + 1]; line[0] = '\0'; addrline[SOCKET_NAME_MAXLEN] = '\0'; for(int i = 0; i < scclNetIfs; i++) { strcpy(scclNetSocketDevs[i].devName, names + i * MAX_IF_NAME_SIZE); memcpy(&scclNetSocketDevs[i].addr, addrs + i, sizeof(union scclSocketAddress)); SCCLCHECK(scclNetSocketGetPciPath(scclNetSocketDevs[i].devName, &scclNetSocketDevs[i].pciPath)); snprintf(line + strlen(line), MAX_LINE_LEN - strlen(line), " [%d]%s:%s", i, names + i * MAX_IF_NAME_SIZE, scclSocketToString(&addrs[i], addrline)); } line[MAX_LINE_LEN] = '\0'; INFO(SCCL_LOG_NET, "NET/Socket : Using%s", line); } } pthread_mutex_unlock(&scclNetSocketLock); } return scclSuccess; } scclResult_t scclNetSocket::devices(int* ndev) { *ndev = scclNetIfs; return scclSuccess; } /** * @brief 获取指定网络设备的速度(单位:Mbps) * * 该函数通过读取/sys/class/net/<设备名>/speed文件来获取网络设备的速度。 * 如果读取失败或速度为0,则默认返回10Gbps(10000Mbps)。 * * @param devName 网络设备名称 * @param speed 输出参数,用于存储获取到的速度值 * @return scclResult_t 始终返回scclSuccess表示成功 */ scclResult_t scclNetSocket::scclNetSocketGetSpeed(char* devName, int* speed) { *speed = 0; char speedPath[PATH_MAX]; sprintf(speedPath, "/sys/class/net/%s/speed", devName); int fd = open(speedPath, O_RDONLY); if(fd != -1) { char speedStr[] = " "; if(read(fd, speedStr, sizeof(speedStr) - 1) > 0) { *speed = strtol(speedStr, NULL, 0); } close(fd); } if(*speed <= 0) { INFO(SCCL_LOG_NET, "Could not get speed from %s. Defaulting to 10 Gbps.", speedPath); *speed = 10000; } return scclSuccess; } /** * @brief 获取网络套接字设备的属性 * * @param dev 设备索引 * @param props 用于存储设备属性的结构体指针 * @return scclResult_t 返回操作结果,scclSuccess表示成功 * * 该函数用于填充指定网络设备的属性信息,包括设备名称、PCI路径、速度等。 * 注意:延迟(latency)和端口(port)属性当前未设置。 */ scclResult_t scclNetSocket::getProperties(int dev, scclNetProperties_t* props) { props->name = scclNetSocketDevs[dev].devName; props->pciPath = scclNetSocketDevs[dev].pciPath; props->guid = dev; props->ptrSupport = SCCL_PTR_HOST; SCCLCHECK(scclNetSocketGetSpeed(props->name, &props->speed)); props->latency = 0; // Not set props->port = 0; props->maxComms = 65536; props->maxRecvs = 1; return scclSuccess; } /** * @brief 持久化socket线程处理函数 * * 该线程持续处理socket任务队列中的任务,每个线程负责处理nSocksPerThread个socket。 * 当任务队列为空时,线程会等待条件变量通知;当收到停止信号时,线程退出。 * * @param args_ 线程参数,包含通信结构、任务队列和同步原语 * @return void* 总是返回NULL * * @note 线程会循环处理任务直到收到停止信号 * @warning 如果socket处理出错,线程会直接退出并打印警告信息 */ void* scclNetSocket::persistentSocketThread(void* args_) { struct scclNetSocketThreadResources* resource = (struct scclNetSocketThreadResources*)args_; struct scclNetSocketComm* comm = resource->comm; struct scclNetSocketTaskQueue* myQueue = &resource->threadTaskQueue; int nSocksPerThread = comm->nSocks / comm->nThreads; while(1) { int idle = 1; int mark = myQueue->next; // mark newest task seen for(int i = 0; i < myQueue->len; i += nSocksPerThread) { int repeat; do { repeat = 0; for(int j = 0; j < nSocksPerThread; j++) { struct scclNetSocketTask* r = myQueue->tasks + i + j; if(r != NULL && r->used == 1 && r->offset < r->size) { r->result = scclSocketProgress(r->op, r->sock, r->data, r->size, &r->offset); if(r->result != scclSuccess) { WARN("NET/Socket : socket progress error"); return NULL; } idle = 0; if(r->offset < r->size) repeat = 1; } } } while(repeat); } if(idle) { pthread_mutex_lock(&resource->threadLock); while(mark == myQueue->next && resource->stop == 0) { // no new tasks, wait pthread_cond_wait(&resource->threadCond, &resource->threadLock); } pthread_mutex_unlock(&resource->threadLock); } if(resource->stop) return NULL; } } /** * @brief 获取指定设备的socket和线程数量配置 * * 根据设备类型和参数配置,自动检测或设置每个线程的socket数量和线程数量。 * 支持AWS和GCP设备的自动检测,并确保配置不超过最大限制。 * * @param dev 设备索引 * @param ns 输出参数,返回总socket数量 * @param nt 输出参数,返回线程数量 * @return scclResult_t 返回操作结果,scclSuccess表示成功 */ scclResult_t scclNetSocket::scclNetSocketGetNsockNthread(int dev, int* ns, int* nt) { int nSocksPerThread = scclParamSocketNsocksPerThread(); int nThreads = scclParamSocketNthreads(); if(nThreads > MAX_THREADS) { WARN("NET/Socket : SCCL_SOCKET_NTHREADS is greater than the maximum allowed, setting to %d", MAX_THREADS); nThreads = MAX_THREADS; } if(nThreads == -2 || nSocksPerThread == -2) { // Auto-detection int autoNt = 0, autoNs = 1; // By default, we only use the main thread and do not spawn extra threads char vendorPath[PATH_MAX]; snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", scclNetSocketDevs[dev].devName); char* rPath = realpath(vendorPath, NULL); int fd = open(rPath, O_RDONLY); free(rPath); if(fd == -1) { // Could not find device vendor. This is handled silently so // we don't want to print an INFO error. INFO(SCCL_LOG_NET, "Open of %s failed : %s", vendorPath, strerror(errno)); goto end; } char vendor[7]; strncpy(vendor, "0x0000", 7); int len; SYSCHECKVAL(read(fd, vendor, 6), "read", len); SYSCHECK(close(fd), "close"); if(strcmp(vendor, "0x1d0f") == 0) { // AWS autoNt = 2; autoNs = 8; } else if(strcmp(vendor, "0x1ae0") == 0) { // GCP autoNt = 4; autoNs = 1; } end: if(nThreads == -2) nThreads = autoNt; if(nSocksPerThread == -2) nSocksPerThread = autoNs; } int nSocks = nSocksPerThread * nThreads; if(nSocks > MAX_SOCKETS) { nSocksPerThread = MAX_SOCKETS / nThreads; WARN("NET/Socket : the total number of sockets is greater than the maximum allowed, setting SCCL_NSOCKS_PERTHREAD to %d", nSocksPerThread); nSocks = nSocksPerThread * nThreads; } *ns = nSocks; *nt = nThreads; if(nSocks > 0) INFO(SCCL_LOG_NET, "NET/Socket: Using %d threads and %d sockets per thread", nThreads, nSocksPerThread); return scclSuccess; } scclResult_t scclNetSocket::listen(int dev, void* opaqueHandle, void** listenComm) { if(dev < 0 || dev >= scclNetIfs) { // data transfer socket is based on specified dev return scclInternalError; } struct scclNetSocketHandle* handle = (struct scclNetSocketHandle*)opaqueHandle; memset(handle, 0, sizeof(struct scclNetSocketHandle)); static_assert(sizeof(struct scclNetSocketHandle) <= SCCL_NET_HANDLE_MAXSIZE, "scclNetSocketHandle size too large"); memset(socketComm, 0, sizeof(struct scclNetSocketListenComm)); handle->magic = SCCL_SOCKET_MAGIC; SCCLCHECK(scclSocketInit(&socketComm->sock, &scclNetSocketDevs[dev].addr, handle->magic, scclSocketTypeNetSocket, NULL, 1)); SCCLCHECK(scclSocketListen(&socketComm->sock)); SCCLCHECK(scclSocketGetAddr(&socketComm->sock, &handle->connectAddr)); SCCLCHECK(scclNetSocketGetNsockNthread(dev, &socketComm->nSocks, &socketComm->nThreads)); handle->nSocks = socketComm->nSocks; handle->nThreads = socketComm->nThreads; socketComm->dev = dev; *listenComm = socketComm; return scclSuccess; } scclResult_t scclNetSocket::connect(int dev, void* opaqueHandle, void** sendComm) { if(dev < 0 || dev >= scclNetIfs) { // data transfer socket is based on specified dev return scclInternalError; } int ready; struct scclNetSocketHandle* handle = (struct scclNetSocketHandle*)opaqueHandle; struct scclNetSocketCommStage* stage = &handle->stage; struct scclNetSocketComm* comm = stage->comm; uint8_t i = stage->iteration; struct scclSocket* sock = stage->sock; *sendComm = NULL; if(stage->state == scclNetSocketCommStateConnect) goto socket_connect_check; if(stage->state == scclNetSocketCommStateSend) goto socket_send; SCCLCHECK(scclCalloc(&comm, 1)); stage->comm = comm; comm->nSocks = handle->nSocks; comm->nThreads = handle->nThreads; comm->dev = dev; HIPCHECK(hipGetDevice(&comm->hipDev)); for(; i < comm->nSocks + 1; i++) { sock = (i == comm->nSocks) ? &comm->ctrlSock : comm->socks + i; SCCLCHECK(scclSocketInit(sock, &handle->connectAddr, handle->magic, scclSocketTypeNetSocket, NULL, 1)); stage->sock = sock; stage->state = scclNetSocketCommStateConnect; stage->iteration = i; SCCLCHECK(scclSocketConnect(sock)); socket_connect_check: SCCLCHECK(scclSocketReady(sock, &ready)); if(!ready) return scclSuccess; stage->state = scclNetSocketCommStateSend; socket_send: int done = 0; SCCLCHECK(scclSocketProgress(SCCL_SOCKET_SEND, sock, &i, sizeof(uint8_t), &done)); if(done == 0) return scclSuccess; } *sendComm = comm; return scclSuccess; } scclResult_t scclNetSocket::accept(void* listenComm, void** recvComm) { struct scclNetSocketListenComm* lComm = (struct scclNetSocketListenComm*)listenComm; struct scclNetSocketCommStage* stage = &lComm->stage; struct scclNetSocketComm* rComm = stage->comm; uint8_t i = stage->iteration; struct scclSocket* sock = stage->sock; int ready; *recvComm = NULL; if(stage->state == scclNetSocketCommStateAccept) goto socket_accept_check; if(stage->state == scclNetSocketCommStateRecv) goto socket_recv; SCCLCHECK(scclCalloc(&rComm, 1)); stage->comm = rComm; rComm->nSocks = lComm->nSocks; rComm->nThreads = lComm->nThreads; rComm->dev = lComm->dev; HIPCHECK(hipGetDevice(&rComm->hipDev)); for(; i < rComm->nSocks + 1; i++) { uint8_t sendSockIdx; SCCLCHECK(scclCalloc(&sock, 1)); SCCLCHECK(scclSocketInit(sock)); stage->sock = sock; stage->state = scclNetSocketCommStateAccept; stage->iteration = i; SCCLCHECK(scclSocketAccept(sock, &lComm->sock)); socket_accept_check: SCCLCHECK(scclSocketReady(sock, &ready)); if(!ready) return scclSuccess; stage->state = scclNetSocketCommStateRecv; socket_recv: int done = 0; SCCLCHECK(scclSocketProgress(SCCL_SOCKET_RECV, sock, &sendSockIdx, sizeof(uint8_t), &done)); if(done == 0) return scclSuccess; if(sendSockIdx == rComm->nSocks) memcpy(&rComm->ctrlSock, sock, sizeof(struct scclSocket)); else memcpy(rComm->socks + sendSockIdx, sock, sizeof(struct scclSocket)); free(sock); } *recvComm = rComm; /* reset lComm state */ stage->state = scclNetSocketCommStateStart; stage->iteration = 0; stage->sock = NULL; stage->comm = NULL; return scclSuccess; } scclResult_t scclNetSocketGetRequest(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketRequest** req) { for(int i = 0; i < MAX_REQUESTS; i++) { struct scclNetSocketRequest* r = comm->requests + i; if(r->used == 0) { r->op = op; r->data = data; r->size = size; r->ctrlSock = &comm->ctrlSock; r->used = 1; r->comm = comm; r->nSubs = 0; *req = r; return scclSuccess; } } WARN("NET/Socket : unable to allocate requests"); return scclInternalError; } scclResult_t scclNetSocket::regMr(void* comm, void* data, int size, int type, void** mhandle) { return (type != SCCL_PTR_HOST) ? scclInternalError : scclSuccess; } scclResult_t scclNetSocket::regMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { WARN("NET/Socket : unable to check DMA-BUF support"); return scclSuccess; } scclResult_t scclNetSocket::deregMr(void* comm, void* mhandle) { return scclSuccess; } scclResult_t scclNetSocket::isend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { struct scclNetSocketComm* comm = (struct scclNetSocketComm*)sendComm; SCCLCHECK(scclNetSocketGetRequest(comm, SCCL_SOCKET_SEND, data, size, (struct scclNetSocketRequest**)request)); return scclSuccess; } scclResult_t scclNetSocket::irecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { struct scclNetSocketComm* comm = (struct scclNetSocketComm*)recvComm; if(n != 1) return scclInternalError; SCCLCHECK(scclNetSocketGetRequest(comm, SCCL_SOCKET_RECV, data[0], sizes[0], (struct scclNetSocketRequest**)request)); return scclSuccess; } scclResult_t scclNetSocket::iflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { // We don't support HIP pointers, so we don't need a flush operation return scclInternalError; } /** * 为指定通信对象创建并获取一个网络套接字任务 * * @param comm 网络套接字通信对象指针 * @param op 操作类型(SCCL_SOCKET_SEND/SCCL_SOCKET_RECV) * @param data 任务数据缓冲区指针 * @param size 数据大小 * @param req [out] 返回创建的任务指针 * * @return 成功返回scclSuccess,失败返回scclInternalError * * @note 该函数会初始化线程资源(首次调用时),创建持久化线程处理任务队列 * @warning 当任务队列已满时会返回错误并打印警告 */ scclResult_t scclNetSocket::scclNetSocketGetTask(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketTask** req) { int tid = comm->nextSock % comm->nThreads; struct scclNetSocketThreadResources* res = comm->threadResources + tid; struct scclNetSocketTaskQueue* queue = &res->threadTaskQueue; // create helper threads and prepare per-thread task queue if(queue->tasks == NULL) { // each request can be divided up to nSocks tasks, and // these tasks are distributed to nThreads threads, // we need to make sure each thread queue has enough slots for MAX_REQUESTS queue->len = MAX_REQUESTS * DIVUP(comm->nSocks, comm->nThreads); SCCLCHECK(scclCalloc(&queue->tasks, queue->len)); queue->next = 0; res->comm = comm; pthread_mutex_init(&res->threadLock, NULL); pthread_cond_init(&res->threadCond, NULL); pthread_create(comm->helperThread + tid, NULL, persistentSocketThread, res); scclSetThreadName(comm->helperThread[tid], "SCCL Sock%c%1u%2u%2u", op == SCCL_SOCKET_SEND ? 'S' : 'R', comm->dev, tid, comm->hipDev); } struct scclNetSocketTask* r = queue->tasks + queue->next; if(r->used == 0) { r->op = op; r->data = data; r->size = size; r->sock = comm->socks + comm->nextSock; r->offset = 0; r->result = scclSuccess; comm->nextSock = (comm->nextSock + 1) % comm->nSocks; r->used = 1; *req = r; pthread_mutex_lock(&res->threadLock); queue->next = (queue->next + 1) % queue->len; pthread_cond_signal(&res->threadCond); pthread_mutex_unlock(&res->threadLock); return scclSuccess; } WARN("NET/Socket : unable to allocate subtasks"); return scclInternalError; } scclResult_t scclNetSocket::test(void* request, int* done, int* size) { *done = 0; struct scclNetSocketRequest* r = (struct scclNetSocketRequest*)request; if(r == NULL) { WARN("NET/Socket : test called with NULL request"); return scclInternalError; } if(r->used == 1) { /* try to send/recv size */ int data = r->size; int offset = 0; SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, &data, sizeof(int), &offset)); if(offset == 0) return scclSuccess; /* Not ready -- retry later */ // Not sure we could ever receive less than 4 bytes, but just in case ... if(offset < sizeof(int)) SCCLCHECK(scclSocketWait(r->op, r->ctrlSock, &data, sizeof(int), &offset)); // Check size is less or equal to the size provided by the user if(r->op == SCCL_SOCKET_RECV && data > r->size) { char line[SOCKET_NAME_MAXLEN + 1]; union scclSocketAddress addr; scclSocketGetAddr(r->ctrlSock, &addr); WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d. If you believe your socket network is in healthy state, \ there may be a mismatch in collective sizes or environment settings (e.g. SCCL_PROTO, SCCL_ALGO) between ranks", scclSocketToString(&addr, line), data, r->size); return scclInvalidUsage; } r->size = data; r->offset = 0; r->used = 2; // done exchanging size // divide into subtasks int chunkOffset = 0, i = 0; if(r->comm->nSocks > 0) { // each request can be divided up to nSocks tasks int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks)); while(chunkOffset < r->size) { int chunkSize = std::min(taskSize, r->size - chunkOffset); SCCLCHECK(scclNetSocketGetTask(r->comm, r->op, (char*)(r->data) + chunkOffset, chunkSize, r->tasks + i++)); chunkOffset += chunkSize; } } r->nSubs = i; } if(r->used == 2) { // already exchanged size if(r->nSubs > 0) { int nCompleted = 0; for(int i = 0; i < r->nSubs; i++) { struct scclNetSocketTask* sub = r->tasks[i]; if(sub->result != scclSuccess) return sub->result; if(sub->offset == sub->size) nCompleted++; } if(nCompleted == r->nSubs) { if(size) *size = r->size; *done = 1; r->used = 0; for(int i = 0; i < r->nSubs; i++) { struct scclNetSocketTask* sub = r->tasks[i]; sub->used = 0; } } } else { // progress request using main thread if(r->offset < r->size) { SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, r->data, r->size, &r->offset)); } if(r->offset == r->size) { if(size) *size = r->size; *done = 1; r->used = 0; } } } return scclSuccess; } scclResult_t scclNetSocket::closeSend(void* opaqueComm) { struct scclNetSocketComm* comm = (struct scclNetSocketComm*)opaqueComm; if(comm) { for(int i = 0; i < comm->nThreads; i++) { struct scclNetSocketThreadResources* res = comm->threadResources + i; if(comm->helperThread[i]) { pthread_mutex_lock(&res->threadLock); res->stop = 1; pthread_cond_signal(&res->threadCond); pthread_mutex_unlock(&res->threadLock); pthread_join(comm->helperThread[i], NULL); } free(res->threadTaskQueue.tasks); } int ready; SCCLCHECK(scclSocketReady(&comm->ctrlSock, &ready)); if(ready) SCCLCHECK(scclSocketClose(&comm->ctrlSock)); for(int i = 0; i < comm->nSocks; i++) { SCCLCHECK(scclSocketReady(&comm->socks[i], &ready)); if(ready) SCCLCHECK(scclSocketClose(&comm->socks[i])); } free(comm); } return scclSuccess; } scclResult_t scclNetSocket::closeRecv(void* opaqueComm) { return closeSend(opaqueComm); } scclResult_t scclNetSocket::closeListen(void* opaqueComm) { struct scclNetSocketListenComm* comm = (struct scclNetSocketListenComm*)opaqueComm; if(comm) { int ready; SCCLCHECK(scclSocketReady(&comm->sock, &ready)); if(ready) SCCLCHECK(scclSocketClose(&comm->sock)); free(comm); } return scclSuccess; } } // namespace net_socket } // namespace net } // namespace hardware } // namespace sccl