#include #include #include #include #include #include "socket.h" #include "net_socket.h" namespace sccl { namespace hardware { namespace net { namespace host { namespace net_socket { #define MAX_LINE_LEN (2047) /* Init functions */ static int scclNetIfs = -1; struct scclNetSocketDev { union scclSocketAddress addr; char devName[MAX_IF_NAME_SIZE]; char* pciPath; }; static struct scclNetSocketDev scclNetSocketDevs[MAX_IFS]; pthread_mutex_t scclNetSocketLock = PTHREAD_MUTEX_INITIALIZER; static scclResult_t scclNetSocketGetPciPath(char* devName, char** pciPath) { char devicePath[PATH_MAX]; snprintf(devicePath, PATH_MAX, "/sys/class/net/%s/device", devName); // May return NULL if the file doesn't exist. *pciPath = realpath(devicePath, NULL); return scclSuccess; } scclResult_t scclNetSocketInit(void) { if(scclNetIfs == -1) { pthread_mutex_lock(&scclNetSocketLock); if(scclNetIfs == -1) { char names[MAX_IF_NAME_SIZE * MAX_IFS]; union scclSocketAddress addrs[MAX_IFS]; scclNetIfs = scclFindSocketInterfaces(names, addrs, MAX_IF_NAME_SIZE, MAX_IFS); if(scclNetIfs <= 0) { WARN("NET/Socket : no interface found"); return scclInternalError; } else { char line[MAX_LINE_LEN + 1]; char addrline[SOCKET_NAME_MAXLEN + 1]; line[0] = '\0'; addrline[SOCKET_NAME_MAXLEN] = '\0'; for(int i = 0; i < scclNetIfs; i++) { strcpy(scclNetSocketDevs[i].devName, names + i * MAX_IF_NAME_SIZE); memcpy(&scclNetSocketDevs[i].addr, addrs + i, sizeof(union scclSocketAddress)); SCCLCHECK(scclNetSocketGetPciPath(scclNetSocketDevs[i].devName, &scclNetSocketDevs[i].pciPath)); snprintf(line + strlen(line), MAX_LINE_LEN - strlen(line), " [%d]%s:%s", i, names + i * MAX_IF_NAME_SIZE, scclSocketToString(&addrs[i], addrline)); } line[MAX_LINE_LEN] = '\0'; INFO(SCCL_LOG_NET, "NET/Socket : Using%s", line); } } pthread_mutex_unlock(&scclNetSocketLock); } return scclSuccess; } scclResult_t scclNetSocketDevices(int* ndev) { *ndev = scclNetIfs; return scclSuccess; } static scclResult_t scclNetSocketGetSpeed(char* devName, int* speed) { *speed = 0; char speedPath[PATH_MAX]; sprintf(speedPath, "/sys/class/net/%s/speed", devName); int fd = open(speedPath, O_RDONLY); if(fd != -1) { char speedStr[] = " "; if(read(fd, speedStr, sizeof(speedStr) - 1) > 0) { *speed = strtol(speedStr, NULL, 0); } close(fd); } if(*speed <= 0) { INFO(SCCL_LOG_NET, "Could not get speed from %s. Defaulting to 10 Gbps.", speedPath); *speed = 10000; } return scclSuccess; } scclResult_t scclNetSocketGetProperties(int dev, scclNetProperties_t* props) { props->name = scclNetSocketDevs[dev].devName; props->pciPath = scclNetSocketDevs[dev].pciPath; props->guid = dev; props->ptrSupport = SCCL_PTR_HOST; SCCLCHECK(scclNetSocketGetSpeed(props->name, &props->speed)); props->latency = 0; // Not set props->port = 0; props->maxComms = 65536; props->maxRecvs = 1; return scclSuccess; } /* Communication functions */ #define MAX_SOCKETS 64 #define MAX_THREADS 16 #define MAX_REQUESTS SCCL_NET_MAX_REQUESTS #define MIN_CHUNKSIZE (64 * 1024) SCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2); SCCL_PARAM(SocketNthreads, "SOCKET_NTHREADS", -2); enum scclNetSocketCommState : uint8_t { scclNetSocketCommStateStart = 0, scclNetSocketCommStateConnect = 1, scclNetSocketCommStateAccept = 3, scclNetSocketCommStateSend = 4, scclNetSocketCommStateRecv = 5, }; struct scclNetSocketCommStage { enum scclNetSocketCommState state; uint8_t iteration; struct scclSocket* sock; struct scclNetSocketComm* comm; }; struct scclNetSocketHandle { union scclSocketAddress connectAddr; uint64_t magic; // random number to help debugging int nSocks; int nThreads; struct scclNetSocketCommStage stage; }; struct scclNetSocketTask { int op; void* data; int size; struct scclSocket* sock; int offset; int used; scclResult_t result; }; struct scclNetSocketRequest { int op; void* data; int size; struct scclSocket* ctrlSock; int offset; int used; struct scclNetSocketComm* comm; struct scclNetSocketTask* tasks[MAX_SOCKETS]; int nSubs; }; struct scclNetSocketTaskQueue { int next; int len; struct scclNetSocketTask* tasks; }; struct scclNetSocketThreadResources { struct scclNetSocketTaskQueue threadTaskQueue; int stop; struct scclNetSocketComm* comm; pthread_mutex_t threadLock; pthread_cond_t threadCond; }; struct scclNetSocketListenComm { struct scclSocket sock; struct scclNetSocketCommStage stage; int nSocks; int nThreads; int dev; }; struct scclNetSocketComm { struct scclSocket ctrlSock; struct scclSocket socks[MAX_SOCKETS]; int dev; int cudaDev; int nSocks; int nThreads; int nextSock; struct scclNetSocketRequest requests[MAX_REQUESTS]; pthread_t helperThread[MAX_THREADS]; struct scclNetSocketThreadResources threadResources[MAX_THREADS]; }; void* persistentSocketThread(void* args_) { struct scclNetSocketThreadResources* resource = (struct scclNetSocketThreadResources*)args_; struct scclNetSocketComm* comm = resource->comm; struct scclNetSocketTaskQueue* myQueue = &resource->threadTaskQueue; int nSocksPerThread = comm->nSocks / comm->nThreads; while(1) { int idle = 1; int mark = myQueue->next; // mark newest task seen for(int i = 0; i < myQueue->len; i += nSocksPerThread) { int repeat; do { repeat = 0; for(int j = 0; j < nSocksPerThread; j++) { struct scclNetSocketTask* r = myQueue->tasks + i + j; if(r != NULL && r->used == 1 && r->offset < r->size) { r->result = scclSocketProgress(r->op, r->sock, r->data, r->size, &r->offset); if(r->result != scclSuccess) { WARN("NET/Socket : socket progress error"); return NULL; } idle = 0; if(r->offset < r->size) repeat = 1; } } } while(repeat); } if(idle) { pthread_mutex_lock(&resource->threadLock); while(mark == myQueue->next && resource->stop == 0) { // no new tasks, wait pthread_cond_wait(&resource->threadCond, &resource->threadLock); } pthread_mutex_unlock(&resource->threadLock); } if(resource->stop) return NULL; } } scclResult_t scclNetSocketGetNsockNthread(int dev, int* ns, int* nt) { int nSocksPerThread = scclParamSocketNsocksPerThread(); int nThreads = scclParamSocketNthreads(); if(nThreads > MAX_THREADS) { WARN("NET/Socket : SCCL_SOCKET_NTHREADS is greater than the maximum allowed, setting to %d", MAX_THREADS); nThreads = MAX_THREADS; } if(nThreads == -2 || nSocksPerThread == -2) { // Auto-detection int autoNt = 0, autoNs = 1; // By default, we only use the main thread and do not spawn extra threads char vendorPath[PATH_MAX]; snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", scclNetSocketDevs[dev].devName); char* rPath = realpath(vendorPath, NULL); int fd = open(rPath, O_RDONLY); free(rPath); if(fd == -1) { // Could not find device vendor. This is handled silently so // we don't want to print an INFO error. INFO(SCCL_LOG_NET, "Open of %s failed : %s", vendorPath, strerror(errno)); goto end; } char vendor[7]; strncpy(vendor, "0x0000", 7); int len; SYSCHECKVAL(read(fd, vendor, 6), "read", len); SYSCHECK(close(fd), "close"); if(strcmp(vendor, "0x1d0f") == 0) { // AWS autoNt = 2; autoNs = 8; } else if(strcmp(vendor, "0x1ae0") == 0) { // GCP autoNt = 4; autoNs = 1; } end: if(nThreads == -2) nThreads = autoNt; if(nSocksPerThread == -2) nSocksPerThread = autoNs; } int nSocks = nSocksPerThread * nThreads; if(nSocks > MAX_SOCKETS) { nSocksPerThread = MAX_SOCKETS / nThreads; WARN("NET/Socket : the total number of sockets is greater than the maximum allowed, setting SCCL_NSOCKS_PERTHREAD to %d", nSocksPerThread); nSocks = nSocksPerThread * nThreads; } *ns = nSocks; *nt = nThreads; if(nSocks > 0) INFO(SCCL_LOG_NET, "NET/Socket: Using %d threads and %d sockets per thread", nThreads, nSocksPerThread); return scclSuccess; } scclResult_t scclNetSocketListen(int dev, void* opaqueHandle, void** listenComm) { if(dev < 0 || dev >= scclNetIfs) { // data transfer socket is based on specified dev return scclInternalError; } struct scclNetSocketHandle* handle = (struct scclNetSocketHandle*)opaqueHandle; memset(handle, 0, sizeof(struct scclNetSocketHandle)); static_assert(sizeof(struct scclNetSocketHandle) <= SCCL_NET_HANDLE_MAXSIZE, "scclNetSocketHandle size too large"); struct scclNetSocketListenComm* comm; SCCLCHECK(scclCalloc(&comm, 1)); handle->magic = SCCL_SOCKET_MAGIC; SCCLCHECK(scclSocketInit(&comm->sock, &scclNetSocketDevs[dev].addr, handle->magic, scclSocketTypeNetSocket, NULL, 1)); SCCLCHECK(scclSocketListen(&comm->sock)); SCCLCHECK(scclSocketGetAddr(&comm->sock, &handle->connectAddr)); SCCLCHECK(scclNetSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads)); handle->nSocks = comm->nSocks; handle->nThreads = comm->nThreads; comm->dev = dev; *listenComm = comm; return scclSuccess; } scclResult_t scclNetSocketConnect(int dev, void* opaqueHandle, void** sendComm) { if(dev < 0 || dev >= scclNetIfs) { // data transfer socket is based on specified dev return scclInternalError; } int ready; struct scclNetSocketHandle* handle = (struct scclNetSocketHandle*)opaqueHandle; struct scclNetSocketCommStage* stage = &handle->stage; struct scclNetSocketComm* comm = stage->comm; uint8_t i = stage->iteration; struct scclSocket* sock = stage->sock; *sendComm = NULL; if(stage->state == scclNetSocketCommStateConnect) goto socket_connect_check; if(stage->state == scclNetSocketCommStateSend) goto socket_send; SCCLCHECK(scclCalloc(&comm, 1)); stage->comm = comm; comm->nSocks = handle->nSocks; comm->nThreads = handle->nThreads; comm->dev = dev; HIPCHECK(hipGetDevice(&comm->cudaDev)); for(; i < comm->nSocks + 1; i++) { sock = (i == comm->nSocks) ? &comm->ctrlSock : comm->socks + i; SCCLCHECK(scclSocketInit(sock, &handle->connectAddr, handle->magic, scclSocketTypeNetSocket, NULL, 1)); stage->sock = sock; stage->state = scclNetSocketCommStateConnect; stage->iteration = i; SCCLCHECK(scclSocketConnect(sock)); socket_connect_check: SCCLCHECK(scclSocketReady(sock, &ready)); if(!ready) return scclSuccess; stage->state = scclNetSocketCommStateSend; socket_send: int done = 0; SCCLCHECK(scclSocketProgress(SCCL_SOCKET_SEND, sock, &i, sizeof(uint8_t), &done)); if(done == 0) return scclSuccess; } *sendComm = comm; return scclSuccess; } scclResult_t scclNetSocketAccept(void* listenComm, void** recvComm) { struct scclNetSocketListenComm* lComm = (struct scclNetSocketListenComm*)listenComm; struct scclNetSocketCommStage* stage = &lComm->stage; struct scclNetSocketComm* rComm = stage->comm; uint8_t i = stage->iteration; struct scclSocket* sock = stage->sock; int ready; *recvComm = NULL; if(stage->state == scclNetSocketCommStateAccept) goto socket_accept_check; if(stage->state == scclNetSocketCommStateRecv) goto socket_recv; SCCLCHECK(scclCalloc(&rComm, 1)); stage->comm = rComm; rComm->nSocks = lComm->nSocks; rComm->nThreads = lComm->nThreads; rComm->dev = lComm->dev; HIPCHECK(hipGetDevice(&rComm->cudaDev)); for(; i < rComm->nSocks + 1; i++) { uint8_t sendSockIdx; SCCLCHECK(scclCalloc(&sock, 1)); SCCLCHECK(scclSocketInit(sock)); stage->sock = sock; stage->state = scclNetSocketCommStateAccept; stage->iteration = i; SCCLCHECK(scclSocketAccept(sock, &lComm->sock)); socket_accept_check: SCCLCHECK(scclSocketReady(sock, &ready)); if(!ready) return scclSuccess; stage->state = scclNetSocketCommStateRecv; socket_recv: int done = 0; SCCLCHECK(scclSocketProgress(SCCL_SOCKET_RECV, sock, &sendSockIdx, sizeof(uint8_t), &done)); if(done == 0) return scclSuccess; if(sendSockIdx == rComm->nSocks) memcpy(&rComm->ctrlSock, sock, sizeof(struct scclSocket)); else memcpy(rComm->socks + sendSockIdx, sock, sizeof(struct scclSocket)); free(sock); } *recvComm = rComm; /* reset lComm state */ stage->state = scclNetSocketCommStateStart; stage->iteration = 0; stage->sock = NULL; stage->comm = NULL; return scclSuccess; } scclResult_t scclNetSocketGetRequest(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketRequest** req) { for(int i = 0; i < MAX_REQUESTS; i++) { struct scclNetSocketRequest* r = comm->requests + i; if(r->used == 0) { r->op = op; r->data = data; r->size = size; r->ctrlSock = &comm->ctrlSock; r->used = 1; r->comm = comm; r->nSubs = 0; *req = r; return scclSuccess; } } WARN("NET/Socket : unable to allocate requests"); return scclInternalError; } scclResult_t scclNetSocketGetTask(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketTask** req) { int tid = comm->nextSock % comm->nThreads; struct scclNetSocketThreadResources* res = comm->threadResources + tid; struct scclNetSocketTaskQueue* queue = &res->threadTaskQueue; // create helper threads and prepare per-thread task queue if(queue->tasks == NULL) { // each request can be divided up to nSocks tasks, and // these tasks are distributed to nThreads threads, // we need to make sure each thread queue has enough slots for MAX_REQUESTS queue->len = MAX_REQUESTS * DIVUP(comm->nSocks, comm->nThreads); SCCLCHECK(scclCalloc(&queue->tasks, queue->len)); queue->next = 0; res->comm = comm; pthread_mutex_init(&res->threadLock, NULL); pthread_cond_init(&res->threadCond, NULL); pthread_create(comm->helperThread + tid, NULL, persistentSocketThread, res); scclSetThreadName(comm->helperThread[tid], "SCCL Sock%c%1u%2u%2u", op == SCCL_SOCKET_SEND ? 'S' : 'R', comm->dev, tid, comm->cudaDev); } struct scclNetSocketTask* r = queue->tasks + queue->next; if(r->used == 0) { r->op = op; r->data = data; r->size = size; r->sock = comm->socks + comm->nextSock; r->offset = 0; r->result = scclSuccess; comm->nextSock = (comm->nextSock + 1) % comm->nSocks; r->used = 1; *req = r; pthread_mutex_lock(&res->threadLock); queue->next = (queue->next + 1) % queue->len; pthread_cond_signal(&res->threadCond); pthread_mutex_unlock(&res->threadLock); return scclSuccess; } WARN("NET/Socket : unable to allocate subtasks"); return scclInternalError; } scclResult_t scclNetSocketTest(void* request, int* done, int* size) { *done = 0; struct scclNetSocketRequest* r = (struct scclNetSocketRequest*)request; if(r == NULL) { WARN("NET/Socket : test called with NULL request"); return scclInternalError; } if(r->used == 1) { /* try to send/recv size */ int data = r->size; int offset = 0; SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, &data, sizeof(int), &offset)); if(offset == 0) return scclSuccess; /* Not ready -- retry later */ // Not sure we could ever receive less than 4 bytes, but just in case ... if(offset < sizeof(int)) SCCLCHECK(scclSocketWait(r->op, r->ctrlSock, &data, sizeof(int), &offset)); // Check size is less or equal to the size provided by the user if(r->op == SCCL_SOCKET_RECV && data > r->size) { char line[SOCKET_NAME_MAXLEN + 1]; union scclSocketAddress addr; scclSocketGetAddr(r->ctrlSock, &addr); WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d. If you believe your socket network is in healthy state, \ there may be a mismatch in collective sizes or environment settings (e.g. SCCL_PROTO, SCCL_ALGO) between ranks", scclSocketToString(&addr, line), data, r->size); return scclInvalidUsage; } r->size = data; r->offset = 0; r->used = 2; // done exchanging size // divide into subtasks int chunkOffset = 0, i = 0; if(r->comm->nSocks > 0) { // each request can be divided up to nSocks tasks int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks)); while(chunkOffset < r->size) { int chunkSize = std::min(taskSize, r->size - chunkOffset); SCCLCHECK(scclNetSocketGetTask(r->comm, r->op, (char*)(r->data) + chunkOffset, chunkSize, r->tasks + i++)); chunkOffset += chunkSize; } } r->nSubs = i; } if(r->used == 2) { // already exchanged size if(r->nSubs > 0) { int nCompleted = 0; for(int i = 0; i < r->nSubs; i++) { struct scclNetSocketTask* sub = r->tasks[i]; if(sub->result != scclSuccess) return sub->result; if(sub->offset == sub->size) nCompleted++; } if(nCompleted == r->nSubs) { if(size) *size = r->size; *done = 1; r->used = 0; for(int i = 0; i < r->nSubs; i++) { struct scclNetSocketTask* sub = r->tasks[i]; sub->used = 0; } } } else { // progress request using main thread if(r->offset < r->size) { SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, r->data, r->size, &r->offset)); } if(r->offset == r->size) { if(size) *size = r->size; *done = 1; r->used = 0; } } } return scclSuccess; } scclResult_t scclNetSocketRegMr(void* comm, void* data, int size, int type, void** mhandle) { return (type != SCCL_PTR_HOST) ? scclInternalError : scclSuccess; } scclResult_t scclNetSocketDeregMr(void* comm, void* mhandle) { return scclSuccess; } scclResult_t scclNetSocketIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { struct scclNetSocketComm* comm = (struct scclNetSocketComm*)sendComm; SCCLCHECK(scclNetSocketGetRequest(comm, SCCL_SOCKET_SEND, data, size, (struct scclNetSocketRequest**)request)); return scclSuccess; } scclResult_t scclNetSocketIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { struct scclNetSocketComm* comm = (struct scclNetSocketComm*)recvComm; if(n != 1) return scclInternalError; SCCLCHECK(scclNetSocketGetRequest(comm, SCCL_SOCKET_RECV, data[0], sizes[0], (struct scclNetSocketRequest**)request)); return scclSuccess; } scclResult_t scclNetSocketIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { // We don't support HIP pointers, so we don't need a flush operation return scclInternalError; } scclResult_t scclNetSocketCloseListen(void* opaqueComm) { struct scclNetSocketListenComm* comm = (struct scclNetSocketListenComm*)opaqueComm; if(comm) { int ready; SCCLCHECK(scclSocketReady(&comm->sock, &ready)); if(ready) SCCLCHECK(scclSocketClose(&comm->sock)); free(comm); } return scclSuccess; } scclResult_t scclNetSocketClose(void* opaqueComm) { struct scclNetSocketComm* comm = (struct scclNetSocketComm*)opaqueComm; if(comm) { for(int i = 0; i < comm->nThreads; i++) { struct scclNetSocketThreadResources* res = comm->threadResources + i; if(comm->helperThread[i]) { pthread_mutex_lock(&res->threadLock); res->stop = 1; pthread_cond_signal(&res->threadCond); pthread_mutex_unlock(&res->threadLock); pthread_join(comm->helperThread[i], NULL); } free(res->threadTaskQueue.tasks); } int ready; SCCLCHECK(scclSocketReady(&comm->ctrlSock, &ready)); if(ready) SCCLCHECK(scclSocketClose(&comm->ctrlSock)); for(int i = 0; i < comm->nSocks; i++) { SCCLCHECK(scclSocketReady(&comm->socks[i], &ready)); if(ready) SCCLCHECK(scclSocketClose(&comm->socks[i])); } free(comm); } return scclSuccess; } } // namespace net_socket scclNet_t scclNetSocket = {"Socket", net_socket::scclNetSocketInit, net_socket::scclNetSocketDevices, net_socket::scclNetSocketGetProperties, net_socket::scclNetSocketListen, net_socket::scclNetSocketConnect, net_socket::scclNetSocketAccept, net_socket::scclNetSocketRegMr, NULL, // No DMA-BUF support net_socket::scclNetSocketDeregMr, net_socket::scclNetSocketIsend, net_socket::scclNetSocketIrecv, net_socket::scclNetSocketIflush, net_socket::scclNetSocketTest, net_socket::scclNetSocketClose, net_socket::scclNetSocketClose, net_socket::scclNetSocketCloseListen}; } // namespace host } // namespace net } // namespace hardware } // namespace sccl