#include #include #include #include #include #include #include #include #include #include #include "bootstrap_net.h" namespace sccl { namespace hardware { namespace topology { namespace bootstrap { namespace bootstrapNet { /* Init functions */ // 用于存储网络接口名称的静态字符数组 static char bootstrapNetIfName[net::MAX_IF_NAME_SIZE + 1]; // 用于存储网络接口地址的静态结构体 static scclSocketAddress_t bootstrapNetIfAddr; // 静态整型变量,用于指示网络初始化是否已完成(0表示未完成,非0表示已完成) static int bootstrapNetInitDone = 0; // 互斥锁,用于保护对上述静态变量的访问,确保线程安全 static pthread_mutex_t bootstrapNetLock = PTHREAD_MUTEX_INITIALIZER; /** * @brief 初始化引导网络 * * 该函数用于初始化SCCL的引导网络。 * 如果设置了 NCCL_COMM_ID 环境变量,则查找一个和该环境变量中指定的 IP 地址处于同一子网的网卡作为 booststrap 网络通信所使用的网卡 bootstrapNetIfAddr * 否则,使用 ncclFindInterfaces 函数选择一个合适的网卡 * * 函数使用互斥锁确保线程安全。 * * @return scclResult_t 返回操作结果: * - scclSuccess: 初始化成功 * - scclInvalidArgument: 无效的SCCL_COMM_ID格式 * - scclSystemError: 找不到匹配的网络接口 * - scclInternalError: 找不到可用的网络接口 */ scclResult_t bootstrapNetInit() { if(bootstrapNetInitDone == 0) { pthread_mutex_lock(&bootstrapNetLock); if(bootstrapNetInitDone == 0) { char* env = getenv("SCCL_COMM_ID"); if(env) { scclSocketAddress_t remoteAddr; if(net::net_socket::scclSocketGetAddrFromString(&remoteAddr, env) != scclSuccess) { WARN("Invalid SCCL_COMM_ID, please use format: : or []: or :"); return scclInvalidArgument; } if(net::net_socket::scclFindInterfaceMatchSubnet(bootstrapNetIfName, &bootstrapNetIfAddr, &remoteAddr, net::MAX_IF_NAME_SIZE, 1) <= 0) { WARN("NET/Socket : No usable listening interface found"); return scclSystemError; } } else { int nIfs = net::net_socket::scclFindSocketInterfaces(bootstrapNetIfName, &bootstrapNetIfAddr, net::MAX_IF_NAME_SIZE, 1); if(nIfs <= 0) { WARN("Bootstrap : no socket interface found"); return scclInternalError; } } char line[net::SOCKET_NAME_MAXLEN + net::MAX_IF_NAME_SIZE + 2]; sprintf(line, "%s:", bootstrapNetIfName); net::net_socket::scclSocketToString(&bootstrapNetIfAddr, line + strlen(line)); INFO(SCCL_LOG_BOOTSTRAP, "Bootstrap : Using %s", line); bootstrapNetInitDone = 1; } pthread_mutex_unlock(&bootstrapNetLock); } return scclSuccess; } scclSocketAddress_t getLocalSocketAddr() { return bootstrapNetIfAddr; } // Additional sync functions /** * 通过网络发送数据 * * @param sock 已连接的socket指针 * @param data 要发送的数据指针 * @param size 要发送的数据大小(字节) * @return scclResult_t 返回操作结果(scclSuccess表示成功) * * @note 先发送数据大小(sizeof(int)),再发送实际数据 */ scclResult_t bootstrapNetSend(scclSocket_t* sock, void* data, int size) { SCCLCHECK(net::net_socket::scclSocketSend(sock, &size, sizeof(int))); SCCLCHECK(net::net_socket::scclSocketSend(sock, data, size)); return scclSuccess; } /** * 从socket接收数据 * * @param sock 要接收数据的socket * @param data 接收数据的缓冲区 * @param size 缓冲区大小 * @return scclResult_t 返回操作结果,成功返回scclSuccess,否则返回错误码 * * @note 如果接收到的数据大小超过缓冲区大小,会截断数据并返回scclInternalError */ scclResult_t bootstrapNetRecv(scclSocket_t* sock, void* data, int size) { int recvSize; SCCLCHECK(net::net_socket::scclSocketRecv(sock, &recvSize, sizeof(int))); if(recvSize > size) { WARN("Message truncated : received %d bytes instead of %d", recvSize, size); return scclInternalError; } SCCLCHECK(net::net_socket::scclSocketRecv(sock, data, std::min(recvSize, size))); return scclSuccess; } } // namespace bootstrapNet } // namespace bootstrap } // namespace topology } // namespace hardware } // namespace sccl