Commit ecf8df33 authored by lishen's avatar lishen
Browse files

优化ipcsocket函数,支持更长的数据传输。修改线程池易卡死的bug

parent 85db7de4
...@@ -41,6 +41,7 @@ scclIpcSocket::~scclIpcSocket() { ...@@ -41,6 +41,7 @@ scclIpcSocket::~scclIpcSocket() {
while(!pthread_pool->allTasksCompleted()) { while(!pthread_pool->allTasksCompleted()) {
usleep(1000); // 每1毫秒检查一次任务完成状态 usleep(1000); // 每1毫秒检查一次任务完成状态
} }
// 释放pthpool // 释放pthpool
if(pthread_pool) { if(pthread_pool) {
delete(pthread_pool); delete(pthread_pool);
...@@ -233,6 +234,84 @@ scclResult_t scclIpcSocket::scclIpcSocketSendFd(const int sendFd, int dst_rank) ...@@ -233,6 +234,84 @@ scclResult_t scclIpcSocket::scclIpcSocketSendFd(const int sendFd, int dst_rank)
return scclSuccess; return scclSuccess;
} }
/**
* @brief 通过IPC socket接收文件描述符
*
* 该函数使用recvmsg系统调用从socket接收文件描述符。函数会循环尝试接收,
* 直到成功或发生错误。接收到的文件描述符会通过参数recvFd返回。
*
* @param recvFd 用于存储接收到的文件描述符的指针
* @return scclResult_t 返回操作结果:
* - scclSuccess: 成功接收文件描述符
* - scclSystemError: 系统调用失败
* - scclInternalError: 操作被中止
*
* @note 函数会处理EAGAIN、EWOULDBLOCK和EINTR错误,其他错误会导致返回失败。
* 接收到的控制消息必须符合SOL_SOCKET级别和SCM_RIGHTS类型。
*/
scclResult_t scclIpcSocket::scclIpcSocketRecvFd(int* recvFd) {
// 初始化消息头结构体和iovec结构体
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
struct iovec iov[1];
// 联合体用于保证控制数组的对齐要求
union {
struct cmsghdr cm;
char control[CMSG_SPACE(sizeof(int))];
} control_un;
struct cmsghdr* cmptr;
char dummy_buffer[1];
int ret;
// 设置消息头的控制信息部分
msg.msg_control = control_un.control;
msg.msg_controllen = sizeof(control_un.control);
// 设置iovec结构体,用于指定要接收的数据
iov[0].iov_base = (void*)dummy_buffer;
iov[0].iov_len = sizeof(dummy_buffer);
// 将iovec结构体关联到消息头
msg.msg_iov = iov;
msg.msg_iovlen = 1;
// 循环接收消息,直到成功接收到数据
while((ret = recvmsg(handle->fd, &msg, 0)) <= 0) {
// 如果接收失败且错误不是EAGAIN, EWOULDBLOCK或EINTR,则记录警告并返回错误
if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
WARN("UDS: Receiving data over socket failed : %d", errno);
return scclSystemError;
}
// 如果设置了中止标志,则返回内部错误
if(handle->abortFlag && *handle->abortFlag)
return scclInternalError;
}
// 检查接收到的控制信息
if(((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
// 如果控制信息的级别或类型不正确,则记录警告并返回错误
if((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
WARN("UDS: Receiving data over socket failed");
return scclSystemError;
}
// 将接收到的文件描述符复制到recvFd
memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));
} else {
// 如果没有接收到控制信息,则记录警告并返回错误
WARN("UDS: Receiving data over socket %s failed", handle->socketName);
return scclSystemError;
}
// 记录成功接收到文件描述符的信息
INFO(SCCL_LOG_BOOTSTRAP, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName);
// 返回成功
return scclSuccess;
}
//////////////////////////////////////////////////////////////////////////////////////////////////////
/** /**
* @brief 通过IPC socket接收文件描述符 * @brief 通过IPC socket接收文件描述符
* *
...@@ -380,52 +459,91 @@ scclResult_t scclIpcSocket::scclIpcSocketRecvData(void* buffer, size_t bufferLen ...@@ -380,52 +459,91 @@ scclResult_t scclIpcSocket::scclIpcSocketRecvData(void* buffer, size_t bufferLen
} }
} }
// 发送数据的方法 /**
* 通过IPC Socket发送数据并等待确认
*
* @param data 要发送的数据指针
* @param dataLen 要发送的数据长度
* @param dst_rank 目标rank号
* @return scclSuccess 发送成功,其他错误码表示失败
*
* 该函数会将数据分块发送(CHUNK_SIZE大小),每发送一个数据块后
* 会等待接收方返回ACK确认。如果收到非预期的ACK或发送/接收失败,
* 会立即返回错误。所有数据成功发送并收到正确ACK后返回成功。
*/
scclResult_t scclIpcSocket::scclIpcSocketSendDataWithAck(const void* data, size_t dataLen, int dst_rank) { scclResult_t scclIpcSocket::scclIpcSocketSendDataWithAck(const void* data, size_t dataLen, int dst_rank) {
scclResult_t result = scclIpcSocketSendData(data, dataLen, dst_rank); const char* dataPtr = static_cast<const char*>(data);
if(result != scclSuccess) { size_t bytesSent = 0;
return result;
}
printf("scclIpcSocketSendDataWithAck localRank=%d, dst_rank=%d\n", localRank, dst_rank);
// 等待接收方的ACK while(bytesSent < dataLen) {
char ack[ACK_SIZE]; size_t bytesToSend = std::min(CHUNK_SIZE, dataLen - bytesSent);
size_t receivedLen;
result = scclIpcSocketRecvData(ack, sizeof(ack), &receivedLen);
printf("scclIpcSocketSendDataWithAck recv ack=%s, localRank=%d, dst_rank=%d\n", ack, localRank, dst_rank);
// 检查是否是预期的ack // 发送数据块
char target_ack[ACK_SIZE]; scclResult_t sendResult = scclIpcSocketSendData(dataPtr + bytesSent, bytesToSend, dst_rank);
sprintf(target_ack, "ACK-%d", localRank); if(sendResult != scclSuccess) {
printf("scclIpcSocketSendDataWithAck 11 check recv ack=%s, target_ack=%s, localRank=%d, dst_rank=%d\n", ack, target_ack, localRank, dst_rank); return sendResult;
}
if(result != scclSuccess || strcmp(ack, target_ack) != 0) { // 等待接收方的ACK
printf("errrrrrr, result=%d, ack=%s, %s\n", result, ack, target_ack); char ack[ACK_SIZE];
return scclSystemError; size_t receivedLen;
scclResult_t recvResult = scclIpcSocketRecvData(ack, sizeof(ack), &receivedLen);
if(recvResult != scclSuccess) {
return recvResult;
}
// 检查是否是预期的ack
char target_ack[ACK_SIZE];
sprintf(target_ack, "ACK-%d", localRank);
if(strcmp(ack, target_ack) != 0) {
WARN("UDS: Received unexpected ACK: %s", ack);
return scclSystemError;
}
bytesSent += bytesToSend;
} }
printf("scclIpcSocketSendDataWithAck 22 check recv ack=%s, target_ack=%s, localRank=%d, dst_rank=%d\n", ack, target_ack, localRank, dst_rank);
INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully sent %zu bytes of data with ACK through UDS socket", dataLen);
return scclSuccess; return scclSuccess;
} }
// 接收数据的方法 /**
* 通过IPC Socket接收数据并发送ACK确认
*
* @param buffer 接收数据缓冲区指针
* @param bufferLen 缓冲区总长度
* @param receivedLen 实际接收到的数据长度(输出参数)
* @param src_rank 发送方rank号
* @return scclSuccess表示成功,其他错误码表示失败
*
* @note 采用分块接收机制,每接收一个数据块都会发送ACK确认
* 接收完成后会记录日志信息
*/
scclResult_t scclIpcSocket::scclIpcSocketRecvDataAndSendAck(void* buffer, size_t bufferLen, size_t* receivedLen, int src_rank) { scclResult_t scclIpcSocket::scclIpcSocketRecvDataAndSendAck(void* buffer, size_t bufferLen, size_t* receivedLen, int src_rank) {
scclResult_t result = scclIpcSocketRecvData(buffer, bufferLen, receivedLen); char* bufferPtr = static_cast<char*>(buffer);
if(result != scclSuccess) { size_t bytesReceived = 0;
return result;
} while(bytesReceived < bufferLen) {
printf("scclIpcSocketRecvDataAndSendAck localRank=%d, src_rank=%d\n", localRank, src_rank); size_t bytesToReceive = std::min(CHUNK_SIZE, bufferLen - bytesReceived);
// 发送ACK给发送方 // 接收数据块
char ack[ACK_SIZE]; scclResult_t recvResult = scclIpcSocketRecvData(bufferPtr + bytesReceived, bytesToReceive, receivedLen);
sprintf(ack, "ACK-%d", src_rank); if(recvResult != scclSuccess) {
printf("scclIpcSocketRecvDataAndSendAck localRank=%d, src_rank=%d, ack=%s\n", localRank, src_rank, ack); return recvResult;
result = scclIpcSocketSendData(ack, strlen(ack), /* 发送方的rank号 */ src_rank); }
if(result != scclSuccess) {
return result; // 发送ACK给发送方
char ack[ACK_SIZE];
sprintf(ack, "ACK-%d", src_rank);
scclResult_t sendResult = scclIpcSocketSendData(ack, strlen(ack), src_rank);
if(sendResult != scclSuccess) {
return sendResult;
}
bytesReceived += *receivedLen;
} }
printf("scclIpcSocketRecvDataAndSendAck send localRank=%d, src_rank=%d, ack=%s\n", localRank, src_rank, ack);
INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully received %zu bytes of data and sent ACK through UDS socket", bufferLen);
return scclSuccess; return scclSuccess;
} }
...@@ -438,14 +556,13 @@ scclResult_t scclIpcSocket::scclIpcSocketRecvDataAndSendAck(void* buffer, size_t ...@@ -438,14 +556,13 @@ scclResult_t scclIpcSocket::scclIpcSocketRecvDataAndSendAck(void* buffer, size_t
* @param sendData 发送数据缓冲区指针 * @param sendData 发送数据缓冲区指针
* @param recvData 接收数据缓冲区指针 * @param recvData 接收数据缓冲区指针
* @param dataLen 每个节点的数据长度(字节) * @param dataLen 每个节点的数据长度(字节)
* @param wait 是否等待所有通信完成
* @return scclResult_t 返回操作结果(scclSuccess表示成功) * @return scclResult_t 返回操作结果(scclSuccess表示成功)
* *
* @note 1. 会跳过本地rank的数据传输 * @note 1. 会跳过本地rank的数据传输
* 2. 数据包格式: [发送rank(int)][数据] * 2. 数据包格式: [发送rank(int)][数据]
* 3. 接收缓冲区需要预先分配足够空间(大小=nlocalRanks*dataLen) * 3. 接收缓冲区需要预先分配足够空间(大小=nlocalRanks*dataLen)
*/ */
scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* recvData, size_t dataLen, bool wait) { scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* recvData, size_t dataLen) {
if(pthread_pool == nullptr || nlocalRanks <= 0) { if(pthread_pool == nullptr || nlocalRanks <= 0) {
WARN("scclIpcSocket init error!"); WARN("scclIpcSocket init error!");
return scclInternalError; return scclInternalError;
...@@ -498,11 +615,9 @@ scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* r ...@@ -498,11 +615,9 @@ scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* r
} }
} }
if(wait) { // 等待所有任务完成
// 等待所有任务完成 while(!pthread_pool->allTasksCompleted()) {
while(!pthread_pool->allTasksCompleted()) { usleep(1000); // 每1毫秒检查一次任务完成状态
usleep(1000); // 每1毫秒检查一次任务完成状态
}
} }
return scclSuccess; return scclSuccess;
...@@ -516,14 +631,13 @@ scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* r ...@@ -516,14 +631,13 @@ scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* r
* @param sendData 发送数据缓冲区指针 * @param sendData 发送数据缓冲区指针
* @param recvData 接收数据缓冲区指针 * @param recvData 接收数据缓冲区指针
* @param dataLen 每个进程发送/接收的数据长度 * @param dataLen 每个进程发送/接收的数据长度
* @param wait 是否等待所有通信任务完成
* @return scclResult_t 返回操作结果,成功返回scclSuccess,失败返回错误码 * @return scclResult_t 返回操作结果,成功返回scclSuccess,失败返回错误码
* *
* @note 1. 函数会先将本地数据复制到接收缓冲区对应位置 * @note 1. 函数会先将本地数据复制到接收缓冲区对应位置
* 2. 使用线程池并行处理与其他进程的通信任务 * 2. 使用线程池并行处理与其他进程的通信任务
* 3. 当wait为true时会阻塞等待所有通信完成 * 3. 当wait为true时会阻塞等待所有通信完成
*/ */
scclResult_t scclIpcSocket::scclIpcSocketAllgatherSync(const void* sendData, void* recvData, size_t dataLen, bool wait) { scclResult_t scclIpcSocket::scclIpcSocketAllgatherSync(const void* sendData, void* recvData, size_t dataLen) {
if(pthread_pool == nullptr || nlocalRanks <= 0) { if(pthread_pool == nullptr || nlocalRanks <= 0) {
WARN("scclIpcSocket init error!"); WARN("scclIpcSocket init error!");
return scclInternalError; return scclInternalError;
...@@ -546,11 +660,9 @@ scclResult_t scclIpcSocket::scclIpcSocketAllgatherSync(const void* sendData, voi ...@@ -546,11 +660,9 @@ scclResult_t scclIpcSocket::scclIpcSocketAllgatherSync(const void* sendData, voi
} }
} }
if(wait) { // 等待所有任务完成
// 等待所有任务完成 while(!pthread_pool->allTasksCompleted()) {
while(!pthread_pool->allTasksCompleted()) { usleep(1000); // 每1毫秒检查一次任务完成状态
usleep(1000); // 每1毫秒检查一次任务完成状态
}
} }
return scclSuccess; return scclSuccess;
...@@ -566,17 +678,13 @@ scclResult_t scclIpcSocket::scclIpcSocketAllgatherSync(const void* sendData, voi ...@@ -566,17 +678,13 @@ scclResult_t scclIpcSocket::scclIpcSocketAllgatherSync(const void* sendData, voi
* @param recvData 接收数据缓冲区指针(非根进程使用) * @param recvData 接收数据缓冲区指针(非根进程使用)
* @param dataLen 数据长度(字节) * @param dataLen 数据长度(字节)
* @param root 根进程的rank值 * @param root 根进程的rank值
* @param wait 是否等待所有通信操作完成
* *
* @return scclResult_t 返回操作结果状态码 * @return scclResult_t 返回操作结果状态码
* - scclSuccess: 操作成功 * - scclSuccess: 操作成功
* - scclInternalError: IPC Socket未初始化或本地rank数无效 * - scclInternalError: IPC Socket未初始化或本地rank数无效
* - scclInvalidArgument: 根进程rank值无效 * - scclInvalidArgument: 根进程rank值无效
*/ */
scclResult_t scclIpcSocket::scclIpcSocketBroadcast(void* data, size_t dataLen, int root, bool wait) { scclResult_t scclIpcSocket::scclIpcSocketBroadcast(void* data, size_t dataLen, int root) {
pthread_pool->allTasksCompleted();
if(pthread_pool == nullptr || nlocalRanks <= 0) { if(pthread_pool == nullptr || nlocalRanks <= 0) {
WARN("scclIpcSocket init error!"); WARN("scclIpcSocket init error!");
return scclInternalError; return scclInternalError;
...@@ -586,34 +694,29 @@ scclResult_t scclIpcSocket::scclIpcSocketBroadcast(void* data, size_t dataLen, i ...@@ -586,34 +694,29 @@ scclResult_t scclIpcSocket::scclIpcSocketBroadcast(void* data, size_t dataLen, i
return scclInvalidArgument; return scclInvalidArgument;
} }
// if(localRank == root) { if(localRank == root) {
// // 根进程:发送数据给所有其他进程 // 根进程:发送数据给所有其他进程
// for(int i = 0; i < nlocalRanks; ++i) { for(int i = 0; i < nlocalRanks; ++i) {
// if(i != root) { if(i != root) {
// // 使用 std::bind 绑定 scclIpcSocketSendDataWithAck 方法和参数 // 使用 std::bind 绑定 scclIpcSocketSendDataWithAck 方法和参数
// auto sendTask = std::bind(&scclIpcSocket::scclIpcSocketSendDataWithAck, this, data, dataLen, i); auto sendTask = std::bind(&scclIpcSocket::scclIpcSocketSendDataWithAck, this, data, dataLen, i);
// // 将绑定后的函数对象添加到线程池的任务队列中 // 将绑定后的函数对象添加到线程池的任务队列中
// pthread_pool->enqueue(sendTask); pthread_pool->enqueue(sendTask);
}
// printf("send root: %d, i=%d\n", root, i); }
// } } else {
// } size_t receivedLen;
// } else { scclResult_t result = scclIpcSocketRecvDataAndSendAck(data, dataLen, &receivedLen, root);
// size_t receivedLen; if(result != scclSuccess) {
// scclResult_t result = scclIpcSocketRecvDataAndSendAck(data, dataLen, &receivedLen, root); return result;
// if(result != scclSuccess) {
// return result;
// }
// printf("recv from root: localRank=%d\n", localRank);
// }
if(wait) {
// 等待所有任务完成
while(!pthread_pool->allTasksCompleted()) {
usleep(1000); // 每1毫秒检查一次任务完成状态
} }
} }
// 等待所有任务完成
while(!pthread_pool->allTasksCompleted()) {
usleep(1000); // 每1毫秒检查一次任务完成状态
}
return scclSuccess; return scclSuccess;
} }
......
...@@ -70,13 +70,13 @@ public: ...@@ -70,13 +70,13 @@ public:
////////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////////
// local rank内的allgather操作。保证接收顺序 // local rank内的allgather操作。保证接收顺序
scclResult_t scclIpcSocketAllgather(const void* sendData, void* recvData, size_t dataLen, bool wait = true); scclResult_t scclIpcSocketAllgather(const void* sendData, void* recvData, size_t dataLen);
// local rank内的allgather操作。为了性能,不保证接收顺序,所以发送的信息中需要添加进程ID // local rank内的allgather操作。为了性能,不保证接收顺序,所以发送的信息中需要添加进程ID
scclResult_t scclIpcSocketAllgatherSync(const void* sendData, void* recvData, size_t dataLen, bool wait = true); scclResult_t scclIpcSocketAllgatherSync(const void* sendData, void* recvData, size_t dataLen);
// local rank内的broadcast操作 // local rank内的broadcast操作
scclResult_t scclIpcSocketBroadcast(void* data, size_t dataLen, int root, bool wait = true); scclResult_t scclIpcSocketBroadcast(void* data, size_t dataLen, int root);
private: private:
// 初始化IPC套接字 // 初始化IPC套接字
...@@ -101,8 +101,12 @@ private: ...@@ -101,8 +101,12 @@ private:
// 线程池指针 // 线程池指针
ThreadPool* pthread_pool = nullptr; ThreadPool* pthread_pool = nullptr;
// 设置超时时间为无限长 // 设置超时时间为无限长
int timeoutMs = -1; int timeoutMs = -1;
// 各种数据大小的固定值
static constexpr int ACK_SIZE = 8; static constexpr int ACK_SIZE = 8;
// 假设 CHUNK_SIZE 是一个合适的块大小,例如 64KB
static constexpr size_t CHUNK_SIZE = 64 * 1024;
}; };
} // namespace ipc_socket } // namespace ipc_socket
......
...@@ -320,19 +320,16 @@ scclResult_t Bootstrap::init(struct BootstrapComm* bootstrap_comm) { ...@@ -320,19 +320,16 @@ scclResult_t Bootstrap::init(struct BootstrapComm* bootstrap_comm) {
scclSocketAddress_t localSocketAddr = bootstrapNet::getLocalSocketAddr(); scclSocketAddress_t localSocketAddr = bootstrapNet::getLocalSocketAddr();
// 设置基础信息 // 设置基础信息
struct BootstrapNodeBasic node_basic{rank, nRanks, hostHash, localSocketAddr}; struct BootstrapNodeBasic node_basic{rank, nRanks, hostHash, localSocketAddr};
printf("Bootstrap::init 111 rank=%d/%d, localRank=%d/%d\n", rank, nRanks, localRank, nLocalRanks);
// -------------------------- 2.设置0号rank搜集的CPU信息和localRank信息 ----------------------------------- // // -------------------------- 2.设置0号rank搜集的CPU信息和localRank信息 ----------------------------------- //
// 创建根节点的数据收集 // 创建根节点的数据收集
struct BootstrapNodeBasic* all_node_basic; std::vector<struct BootstrapNodeBasic> all_node_basic;
SCCLCHECK(scclCalloc(&all_node_basic, nRanks)); // 节点间用于传输数据的基础信息 all_node_basic.reserve(nRanks);
SCCLCHECK(bootstrapRootGatherAndBroadcast(&node_basic, all_node_basic)); SCCLCHECK(bootstrapRootGatherAndBroadcast(&node_basic, all_node_basic.data()));
printf("Bootstrap::init 222 rank=%d/%d, localRank=%d/%d\n", rank, nRanks, localRank, nLocalRanks);
// -------------------------- 3.设置本地localRank的BootstrapComm信息 ----------------------------------- // // -------------------------- 3.设置本地localRank的BootstrapComm信息 ----------------------------------- //
// 初始化BootstrapComm类 // 初始化BootstrapComm类
bootstrap_comm->init(rank, nRanks, localRank, nLocalRanks); bootstrap_comm->init(rank, nRanks, localRank, nLocalRanks);
printf("Bootstrap::init 333 rank=%d/%d, localRank=%d/%d\n", rank, nRanks, localRank, nLocalRanks);
if(CPU_COUNT(&bootstrap_comm->cpuAffinity)) { if(CPU_COUNT(&bootstrap_comm->cpuAffinity)) {
sched_setaffinity(0, sizeof(cpu_set_t), &bootstrap_comm->cpuAffinity); sched_setaffinity(0, sizeof(cpu_set_t), &bootstrap_comm->cpuAffinity);
...@@ -371,57 +368,7 @@ scclResult_t Bootstrap::init(struct BootstrapComm* bootstrap_comm) { ...@@ -371,57 +368,7 @@ scclResult_t Bootstrap::init(struct BootstrapComm* bootstrap_comm) {
// -------------------------- 4.BootstrapComm信息的allgather ----------------------------------- // // -------------------------- 4.BootstrapComm信息的allgather ----------------------------------- //
// bootstrap_comm = new BootstrapComm(rank_info->nRanks); // bootstrapAllGather(bootstrap_comm->node_info);
// auto node_info_vec = bootstrap_comm->node_info_set->node_info_vec;
// struct scclNodeInfo* local_node_info = node_info_vec[0];
// constexpr int handle_size = sizeof(struct BootstrapHandle);
// auto handleBufferChr = reinterpret_cast<const char*>(handleBuffer);
// for(int i = 0; i < rank_info->nRanks; ++i) {
// auto temp_bootstrap_handle = deserializeBootstrapData<BootstrapHandle>(&handleBufferChr[i * handle_size]);
// int rank = temp_bootstrap_handle.rank;
// uint64_t hostHash = temp_bootstrap_handle.hostHash;
// // scclSocketAddress_t addr; // 地址,用于网络通信
// printf("bootstrapInit rank=%d, i=%d, hostHash=%lu\n", rank, i, hostHash);
// }
// printUniqueInfo(bootstrap_comm->unique_info);
// // 如果已经初始化,直接返回成功
// if(asm_ops::ld_acquire_sys_global(&initialized))
// return scclSuccess;
// // 加锁以确保初始化过程的线程安全
// pthread_mutex_lock(&bootstrapNetLock);
// // 如果尚未初始化,进行初始化操作
// if(!initialized) {
// // -------------------------- 1.设置各种属性 ----------------------------------- //
// // 获取CPU亲和性
// sched_getaffinity(0, sizeof(cpu_set_t), &bootstrap_comm->cpuAffinity);
// bootstrap_comm->nRanks = rank_info->nRanks;
// uint32_t devices_num;
// SCCLCHECK(rocm_smi_init()); // 初始化ROCM SMI库
// SCCLCHECK(rocm_smi_getNumDevice(&devices_num)); // 获取设备数量
// LTCHECK(devices_num, 0); // 检查设备数量是否大于0
// bootstrap_comm->deviceCnt = static_cast<int>(devices_num); // 将设备数量转换为int并赋值给的deviceCnt
// printf("devices_num=%s\n", bootstrap_comm->deviceCnt);
// // SCCLCHECK(getIpcSocketAddr(&handle->peerIpcAddr));
// #if 0
// // char line[100];
// // sprintf(line, "pos 55: rank=%d", rank_info->rank);
// // SCCLCHECK(net::printSocketAddr(&handle->addr, line));
// #endif
// // bootstrapAllGather(bootstrap_comm->node_info);
// // 设置初始化完成标志
// asm_ops::st_release_sys_global(&initialized, true);
// }
// // 解锁
// pthread_mutex_unlock(&bootstrapNetLock);
// 设置初始化标志 // 设置初始化标志
asm_ops::st_release_sys_global(&socketInitDone, true); asm_ops::st_release_sys_global(&socketInitDone, true);
...@@ -475,11 +422,10 @@ scclResult_t Bootstrap::bootstrapRootGatherAndBroadcast(void* send_data, void* r ...@@ -475,11 +422,10 @@ scclResult_t Bootstrap::bootstrapRootGatherAndBroadcast(void* send_data, void* r
net::net_socket::scclSocketAcceptManager accept_manager(my_listen_sock); net::net_socket::scclSocketAcceptManager accept_manager(my_listen_sock);
SCCLCHECK(bootstrapNet::bootstrapNetRecv(accept_manager.getSocket(), recv_data_basic, recv_data_basic_size)); SCCLCHECK(bootstrapNet::bootstrapNetRecv(accept_manager.getSocket(), recv_data_basic, recv_data_basic_size));
} }
printf("Bootstrap::bootstrapRootGatherAndBroadcast 444 rank=%d/%d, localRank=%d/%d\n", rank, nRanks, localRank, nLocalRanks);
// ------------- 5.nLocalRanks==0时,将所有rank的ip数据广播给节点内其他rank ------------- // // ------------- 5.nLocalRanks==0时,将所有rank的ip数据广播给节点内其他rank ------------- //
ipcsocket = new scclIpcSocket_t(localRank, nLocalRanks, /*hash*/ root_handle->magic); ipcsocket = new scclIpcSocket_t(localRank, nLocalRanks, /*hash*/ root_handle->magic);
ipcsocket->scclIpcSocketBroadcast(recv_data_basic, recv_data_basic_size, 0, /*wait*/ true); ipcsocket->scclIpcSocketBroadcast(recv_data_basic, recv_data_basic_size, /*localRank root*/ 0);
return scclSuccess; return scclSuccess;
} }
......
...@@ -9,18 +9,20 @@ static constexpr int THREADS_POOL_MAX_SIZE = 128; ...@@ -9,18 +9,20 @@ static constexpr int THREADS_POOL_MAX_SIZE = 128;
/** /**
* @brief 线程池构造函数 * @brief 线程池构造函数
* *
* 初始化线程池,创建指定数量的工作线程。 * 初始化线程池,创建指定数量的工作线程并设置CPU亲和性
* *
* @param threads_num 线程池中初始线程数量,不超过THREADS_POOL_MAX_SIZE限制 * @param threads_num 线程池中线程的数量
* @param cpu_cord_offset CPU亲和性设置的起始偏移量(跳过核心0)
* *
* @note 会初始化互斥锁和条件变量,并启动工作线程执行ThreadPool::run函数 * @note 线程数量会被限制在THREADS_POOL_MAX_SIZE以内
* @note 每个工作线程会被绑定到不同的CPU核心,从cpu_cord_offset开始
*/ */
ThreadPool::ThreadPool(size_t threads_num, int cpu_cord_offset) : stop(false), active_tasks(0) { ThreadPool::ThreadPool(size_t threads_num, int cpu_cord_offset) : stop(false), active_tasks(0) {
threads_num = min(THREADS_POOL_MAX_SIZE, threads_num); threads_num = min(THREADS_POOL_MAX_SIZE, threads_num);
pthread_mutex_init(&queue_mutex, nullptr); pthread_mutex_init(&queue_mutex, nullptr);
pthread_cond_init(&condition, nullptr); pthread_cond_init(&condition, nullptr);
// printf("ThreadPool 构造函数"); workers.reserve(threads_num);
for(size_t i = 0; i < threads_num; ++i) { for(size_t i = 0; i < threads_num; ++i) {
pthread_t worker; pthread_t worker;
...@@ -87,8 +89,6 @@ void* ThreadPool::run(void* arg) { ...@@ -87,8 +89,6 @@ void* ThreadPool::run(void* arg) {
task(); // 执行任务 task(); // 执行任务
{ {
pthread_mutex_lock(&pool->queue_mutex); pthread_mutex_lock(&pool->queue_mutex);
printf("ThreadPool active_tasks--");
pool->active_tasks--; // 任务完成减少活动任务计数 pool->active_tasks--; // 任务完成减少活动任务计数
pthread_mutex_unlock(&pool->queue_mutex); pthread_mutex_unlock(&pool->queue_mutex);
} }
...@@ -103,7 +103,6 @@ void* ThreadPool::run(void* arg) { ...@@ -103,7 +103,6 @@ void* ThreadPool::run(void* arg) {
*/ */
bool ThreadPool::allTasksCompleted() { bool ThreadPool::allTasksCompleted() {
pthread_mutex_lock(&queue_mutex); pthread_mutex_lock(&queue_mutex);
printf("active_tasks: %d, tasks.size(): %lu\n", active_tasks, tasks.size());
bool completed = (active_tasks == 0) && tasks.empty(); bool completed = (active_tasks == 0) && tasks.empty();
pthread_mutex_unlock(&queue_mutex); pthread_mutex_unlock(&queue_mutex);
return completed; return completed;
......
...@@ -25,11 +25,8 @@ public: ...@@ -25,11 +25,8 @@ public:
std::future<return_type> res = task->get_future(); std::future<return_type> res = task->get_future();
{ {
pthread_mutex_lock(&queue_mutex); pthread_mutex_lock(&queue_mutex);
tasks.push([task]() { (*task)(); }); tasks.push([task]() { (*task)(); });
active_tasks++; // 新任务增加活动任务计数 active_tasks++; // 新任务增加活动任务计数
// printf("ThreadPool active_tasks++");
pthread_mutex_unlock(&queue_mutex); pthread_mutex_unlock(&queue_mutex);
pthread_cond_signal(&condition); pthread_cond_signal(&condition);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment