ipc_socket.cpp 22.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <pthread.h>
#include <stdlib.h>
#include <poll.h>
#include <limits.h>
#include <fcntl.h>
#include <thread> // 为了使用 std::this_thread::sleep_for
#include "ipc_socket.h"

namespace sccl {
namespace hardware {
namespace net {

namespace ipc_socket {

//////////////////////////////////////// scclIpcSocket调用的函数 ////////////////////////////////////////
16
17
scclIpcSocket::scclIpcSocket(int localRank, int nlocalRanks, uint64_t hash, volatile uint32_t* abortFlag)
    : localRank(localRank), nlocalRanks(nlocalRanks), ipc_hash(hash) {
18
    scclResult_t res;
19
20
21
22
    // 初始化handle
    handle                = new struct scclIpcSocketHandle();
    handle->fd            = -1;
    handle->socketName[0] = '\0';
23

24
25
26
27
28
    // 设置线程池
    if(nlocalRanks > 0) {
        pthread_pool = new ThreadPool(nlocalRanks * 2); // 其中一半用于发送一半,用于接收
    } else {
        goto failure;
29
30
31
32
33
34
35
36
37
38
39
    }

    SCCLCHECKGOTO(scclIpcSocketInit(abortFlag), res, failure);
    return;

failure:
    WARN("scclIpcSocket init failed");
    return;
}

scclIpcSocket::~scclIpcSocket() {
40
41
42
43
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
        usleep(1000); // 每1毫秒检查一次任务完成状态
    }
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
    // 释放pthpool
    if(pthread_pool) {
        delete(pthread_pool);
    }

    // 释放handle
    if(handle->socketName[0] != '\0') {
        unlink(handle->socketName);
    }
    if(handle->fd >= 0) {
        close(handle->fd);
    }
    delete(handle);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

scclResult_t scclIpcSocket::scclIpcSocketInit(volatile uint32_t* abortFlag) {
    // 中间变量
    int fd = -1;
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];

    // 创建Unix域套接字
    // af是本机IP地址类型,一般有PF_INET或者AF_INET(IPv4互联网协议族),还有PF_INET6(IPv6互联网协议族)等,但是一般用IPv4。
    // type有两种SOCK_STREAM 和SOCK_DGRAM分别对应tcp和udp协议,区别是用不用建立连接。
    if((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) {
        WARN("UDS: Socket creation error : %d", errno);
        return scclSystemError;
    }

    // 将cliaddr结构体清零,确保没有残留数据
    bzero(&my_cliaddr, sizeof(my_cliaddr));
    my_cliaddr.sun_family = AF_UNIX;

    // 为套接字创建唯一名称
79
80
    int len;
    SCCLCHECK(getScclIpcSocknameStr(localRank, ipc_hash, temp_addr, &len));
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Creating socket %s", temp_addr);

    // 设置套接字路径
    strncpy(my_cliaddr.sun_path, temp_addr, len);
    my_cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧

    // 绑定套接字
    if(bind(fd, (struct sockaddr*)&my_cliaddr, sizeof(my_cliaddr)) < 0) {
        WARN("UDS: Binding to socket %s failed : %d", temp_addr, errno);
        close(fd);
        return scclSystemError;
    }

    // 设置handle的成员变量
    handle->fd = fd;
    strcpy(handle->socketName, temp_addr);

    // 设置中止标志
    handle->abortFlag = abortFlag;
    // 将套接字标记为非阻塞
    if(handle->abortFlag) {
        int flags;
        EQCHECK(flags = fcntl(fd, F_GETFL), -1);
        SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
    }

    return scclSuccess;
}

/**
 * 设置中止标志并更新socket的非阻塞模式
 *
 * @param flag 指向中止标志的指针。如果非空,将socket设为非阻塞模式;
 *             如果为空,则恢复为阻塞模式。
 * @note 该函数仅在handle有效时执行操作
 */
scclResult_t scclIpcSocket::setAbortFlag(volatile uint32_t* flag) {
    if(handle) {
        handle->abortFlag = flag;
        if(flag) {
            int flags;
            EQCHECK(flags = fcntl(handle->fd, F_GETFL), -1);
            SYSCHECK(fcntl(handle->fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
        } else {
            int flags;
            EQCHECK(flags = fcntl(handle->fd, F_GETFL), -1);
            SYSCHECK(fcntl(handle->fd, F_SETFL, flags & ~O_NONBLOCK), "fcntl");
        }
    }
    return scclSuccess;
}

// 获取 abortFlag 的函数
volatile uint32_t* scclIpcSocket::getAbortFlag() const { return handle ? handle->abortFlag : nullptr; }

/**
 * 设置IPC套接字的超时时间
 *
 * @param timeout_ms 超时时间(毫秒)
 * @return 成功返回scclSuccess
 */
scclResult_t scclIpcSocket::setTimeout(int timeout_ms) {
    timeoutMs = timeout_ms;
    return scclSuccess;
}

ThreadPool* scclIpcSocket::getPthreadPool() { return pthread_pool; }

//////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * @brief 通过Unix域套接字发送文件描述符
 *
 * @param sendFd 要发送的文件描述符
 * @param dst_rank 目标rank号
 * @return scclResult_t 返回操作结果:
 *         - scclSuccess: 发送成功
 *         - scclInternalError: 内部错误(如地址过长或中止标志被设置)
 *         - scclSystemError: 系统调用错误
 *
 * @note 使用Linux抽象套接字技巧(将sun_path[0]置为'\0')
 *       通过SCM_RIGHTS机制发送文件描述符
 *       函数会循环尝试发送直到成功或遇到错误
 */
scclResult_t scclIpcSocket::scclIpcSocketSendFd(const int sendFd, int dst_rank) {
    // 创建一个临时地址字符串
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];
    // 格式化地址字符串
168
169
170
    int len;
    SCCLCHECK(getScclIpcSocknameStr(dst_rank, ipc_hash, temp_addr, &len));

171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
    // 记录发送文件描述符的信息
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Sending fd %d to UDS socket %s/fd:%d", sendFd, temp_addr, handle->fd);

    // 初始化消息头结构体和iovec结构体
    struct msghdr msg;
    struct iovec iov[1];

    // 联合体用于保证控制数组的对齐要求
    union {
        struct cmsghdr cm;
        char control[CMSG_SPACE(sizeof(int))];
    } control_un;

    struct cmsghdr* cmptr;
    struct sockaddr_un cliaddr;

    // 构造客户端地址以发送共享句柄
    bzero(&cliaddr, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;
    strncpy(cliaddr.sun_path, temp_addr, len);
    cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧

    // 设置消息头的控制信息部分
    msg.msg_control    = control_un.control;
    msg.msg_controllen = sizeof(control_un.control);

    cmptr             = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_len   = CMSG_LEN(sizeof(int));
    cmptr->cmsg_level = SOL_SOCKET;
    cmptr->cmsg_type  = SCM_RIGHTS;

    // 将要发送的文件描述符复制到控制信息中
    memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd));

    // 设置消息头的地址信息部分
    msg.msg_name    = (void*)&cliaddr;
    msg.msg_namelen = sizeof(struct sockaddr_un);

    // 设置iovec结构体,用于指定要发送的数据
    iov[0].iov_base = (void*)"";
    iov[0].iov_len  = 1;

    // 将iovec结构体关联到消息头
    msg.msg_iov    = iov;
    msg.msg_iovlen = 1;

    // 初始化消息标志
    msg.msg_flags = 0;

    ssize_t sendResult;
    // 循环发送消息,直到成功发送数据
    while((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) {
        // 如果发送失败且错误不是EAGAIN, EWOULDBLOCK或EINTR,则记录警告并返回错误
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
            WARN("UDS: Sending data over socket %s failed : %d", temp_addr, errno);
            return scclSystemError;
        }
        // 如果设置了中止标志,则返回内部错误
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;
    }

    return scclSuccess;
}

/**
 * @brief 通过IPC socket接收文件描述符
 *
 * 该函数使用recvmsg系统调用从socket接收文件描述符。函数会循环尝试接收,
 * 直到成功或发生错误。接收到的文件描述符会通过参数recvFd返回。
 *
 * @param recvFd 用于存储接收到的文件描述符的指针
 * @return scclResult_t 返回操作结果:
 *         - scclSuccess: 成功接收文件描述符
 *         - scclSystemError: 系统调用失败
 *         - scclInternalError: 操作被中止
 *
 * @note 函数会处理EAGAIN、EWOULDBLOCK和EINTR错误,其他错误会导致返回失败。
 *       接收到的控制消息必须符合SOL_SOCKET级别和SCM_RIGHTS类型。
 */
scclResult_t scclIpcSocket::scclIpcSocketSendData(const void* data, size_t dataLen, int dst_rank) {
    // 构造目标地址字符串
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];
254
255
    int len;
    SCCLCHECK(getScclIpcSocknameStr(dst_rank, ipc_hash, temp_addr, &len));
256
257
258
259
260
261
262
263
264

    // 设置消息结构体
    struct msghdr msg;
    struct iovec iov[1];
    struct sockaddr_un cliaddr;
    bzero(&cliaddr, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;
    strncpy(cliaddr.sun_path, temp_addr, len);
    cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧
265
266
267
268
269
    msg.msg_name        = (void*)&cliaddr;
    msg.msg_namelen     = sizeof(cliaddr);
    msg.msg_control     = NULL;
    msg.msg_controllen  = 0;
    msg.msg_flags       = 0;
270

271
272
273
274
    iov[0].iov_base = (void*)data;
    iov[0].iov_len  = dataLen;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 1;
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310

    // 使用 poll 等待 socket 可写
    struct pollfd pfd;
    pfd.fd     = handle->fd;
    pfd.events = POLLOUT;

    int pollResult = poll(&pfd, 1, timeoutMs);
    if(pollResult <= 0) {
        if(pollResult == 0) {
            WARN("UDS: Timeout occurred while waiting to send data to socket %s", temp_addr);
        } else {
            WARN("UDS: Error occurred while polling socket %s for writability : %d", temp_addr, errno);
        }
        return scclSystemError;
    }

    ssize_t sendResult;
    while((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) {
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
            WARN("UDS: Error occurred while sending data through socket %s : %d", temp_addr, errno);
            return scclSystemError;
        }
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;

        pollResult = poll(&pfd, 1, timeoutMs);
        if(pollResult <= 0) {
            if(pollResult == 0) {
                WARN("UDS: Timeout occurred while waiting to send data to socket %s", temp_addr);
            } else {
                WARN("UDS: Error occurred while polling socket %s for writability : %d", temp_addr, errno);
            }
            return scclSystemError;
        }
    }

311
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully sent %zd bytes of data through UDS socket %s", sendResult, temp_addr);
312
313
314
315
    return scclSuccess;
}

/**
316
 * @brief 通过IPC套接字发送数据到指定目标rank
317
 *
318
319
320
321
322
323
324
 * @param data 要发送的数据指针
 * @param dataLen 要发送的数据长度
 * @param dst_rank 目标rank号
 * @return scclResult_t 返回操作结果状态码:
 *         - scclSuccess: 发送成功
 *         - scclInternalError: 内部错误(如套接字名称过长或中止标志被设置)
 *         - scclSystemError: 系统调用错误(如poll超时或sendmsg失败)
325
 *
326
327
 * @note 使用Linux抽象套接字技术,通过poll机制确保套接字可写后再发送数据
 *       支持EAGAIN/EWOULDBLOCK/EINTR错误重试机制
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
 */
scclResult_t scclIpcSocket::scclIpcSocketRecvData(void* buffer, size_t bufferLen, size_t* receivedLen) {
    // 设置消息结构体
    struct msghdr msg = {0};
    struct iovec iov[1];
    iov[0].iov_base = buffer;
    iov[0].iov_len  = bufferLen;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 1;

    // 使用 poll 等待 socket 可读
    struct pollfd pfd;
    pfd.fd     = handle->fd;
    pfd.events = POLLIN;

    int pollResult = poll(&pfd, 1, timeoutMs);
    if(pollResult <= 0) {
        if(pollResult == 0) {
            WARN("UDS: Timeout occurred while waiting to receive data from socket %s", handle->socketName);
        } else {
            WARN("UDS: Error occurred while polling socket %s for readability : %d", handle->socketName, errno);
        }
        return scclSystemError;
    }

    int ret;
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
    while(true) {
        ret = recvmsg(handle->fd, &msg, 0);
        if(ret > 0) {
            INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully received %d bytes of data from socket %s", ret, handle->socketName);
            *receivedLen = ret;
            return scclSuccess;
        } else if(ret == 0) {
            INFO(SCCL_LOG_BOOTSTRAP, "UDS: Connection closed by peer on socket %s", handle->socketName);
            *receivedLen = 0;
            return scclSuccess;
        } else {
            if(errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
                pollResult = poll(&pfd, 1, timeoutMs);
                if(pollResult <= 0) {
                    if(pollResult == 0) {
                        WARN("UDS: Timeout occurred while waiting to receive data from socket %s", handle->socketName);
                    } else {
                        WARN("UDS: Error occurred while polling socket %s for readability : %d", handle->socketName, errno);
                    }
                    return scclSystemError;
                }
375
            } else {
376
377
                WARN("UDS: Error occurred while receiving data through socket %s : %d", handle->socketName, errno);
                return scclSystemError;
378
379
380
381
382
            }
        }
    }
}

383
384
385
386
387
// 发送数据的方法
scclResult_t scclIpcSocket::scclIpcSocketSendDataWithAck(const void* data, size_t dataLen, int dst_rank) {
    scclResult_t result = scclIpcSocketSendData(data, dataLen, dst_rank);
    if(result != scclSuccess) {
        return result;
388
    }
389
    printf("scclIpcSocketSendDataWithAck localRank=%d, dst_rank=%d\n", localRank, dst_rank);
390

391
392
393
394
395
396
397
398
399
400
401
402
403
404
    // 等待接收方的ACK
    char ack[ACK_SIZE];
    size_t receivedLen;
    result = scclIpcSocketRecvData(ack, sizeof(ack), &receivedLen);
    printf("scclIpcSocketSendDataWithAck recv ack=%s, localRank=%d, dst_rank=%d\n", ack, localRank, dst_rank);

    // 检查是否是预期的ack
    char target_ack[ACK_SIZE];
    sprintf(target_ack, "ACK-%d", localRank);
    printf("scclIpcSocketSendDataWithAck 11 check recv ack=%s, target_ack=%s, localRank=%d, dst_rank=%d\n", ack, target_ack, localRank, dst_rank);

    if(result != scclSuccess || strcmp(ack, target_ack) != 0) {
        printf("errrrrrr, result=%d, ack=%s, %s\n", result, ack, target_ack);
        return scclSystemError;
405
    }
406
407
    printf("scclIpcSocketSendDataWithAck 22 check recv ack=%s, target_ack=%s, localRank=%d, dst_rank=%d\n", ack, target_ack, localRank, dst_rank);

408
409
    return scclSuccess;
}
410
411
412
413
414
415

// 接收数据的方法
scclResult_t scclIpcSocket::scclIpcSocketRecvDataAndSendAck(void* buffer, size_t bufferLen, size_t* receivedLen, int src_rank) {
    scclResult_t result = scclIpcSocketRecvData(buffer, bufferLen, receivedLen);
    if(result != scclSuccess) {
        return result;
416
    }
417
418
419
420
421
422
423
424
425
    printf("scclIpcSocketRecvDataAndSendAck localRank=%d, src_rank=%d\n", localRank, src_rank);

    // 发送ACK给发送方
    char ack[ACK_SIZE];
    sprintf(ack, "ACK-%d", src_rank);
    printf("scclIpcSocketRecvDataAndSendAck localRank=%d, src_rank=%d, ack=%s\n", localRank, src_rank, ack);
    result = scclIpcSocketSendData(ack, strlen(ack), /* 发送方的rank号 */ src_rank);
    if(result != scclSuccess) {
        return result;
426
    }
427
428
429
    printf("scclIpcSocketRecvDataAndSendAck send localRank=%d, src_rank=%d, ack=%s\n", localRank, src_rank, ack);

    return scclSuccess;
430
431
}

432
/////////////////////////////////////////////////////////////////////////////////////////////////////
433
434
435
436
437
438
439
440
441
442
443
444
445
/**
 * @brief 使用IPC套接字实现Allgather操作
 *
 * 该函数通过线程池并行发送和接收数据,实现多节点间的Allgather集合通信。
 *
 * @param sendData 发送数据缓冲区指针
 * @param recvData 接收数据缓冲区指针
 * @param dataLen 每个节点的数据长度(字节)
 * @param wait 是否等待所有通信完成
 * @return scclResult_t 返回操作结果(scclSuccess表示成功)
 *
 * @note 1. 会跳过本地rank的数据传输
 *       2. 数据包格式: [发送rank(int)][数据]
446
 *       3. 接收缓冲区需要预先分配足够空间(大小=nlocalRanks*dataLen)
447
448
 */
scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* recvData, size_t dataLen, bool wait) {
449
    if(pthread_pool == nullptr || nlocalRanks <= 0) {
450
451
452
453
454
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }

    // 采用线程池发送和接收数据
455
    for(int i = 0; i < nlocalRanks; ++i) {
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
        if(i != localRank) {
            auto sendTask = [this, sendData, dataLen, i]() {
                // 计算 DataPackage 的总大小
                size_t packageSize = sizeof(int) + dataLen;
                char* buffer       = new char[packageSize];

                // 将 rank 信息和数据一起拷贝到 buffer 中
                int* rankPtr = reinterpret_cast<int*>(buffer);
                *rankPtr     = localRank;

                char* dataPtr = buffer + sizeof(int);
                memcpy(dataPtr, sendData, dataLen);

                // 一次性发送 rank 信息和数据
                scclIpcSocketSendData(buffer, packageSize, i);

                delete[] buffer;
            };
474
            pthread_pool->enqueue(sendTask);
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493

            auto recvTask = [this, recvData, dataLen, i]() {
                // 准备接收缓冲区
                size_t packageSize = sizeof(int) + dataLen;
                char* buffer       = new char[packageSize];
                size_t receivedLen;

                // 一次性接收 rank 信息和数据
                scclIpcSocketRecvData(buffer, packageSize, &receivedLen);

                // 从 buffer 中提取 rank 信息和数据
                int* rankPtr   = reinterpret_cast<int*>(buffer);
                int senderRank = *rankPtr;

                char* dataPtr = buffer + sizeof(int);
                memcpy(static_cast<char*>(recvData) + senderRank * dataLen, dataPtr, dataLen);

                delete[] buffer;
            };
494
            pthread_pool->enqueue(recvTask);
495
496
497
498
499
500
501
502
        } else {
            // 自己的数据直接放置到正确位置
            memcpy(static_cast<char*>(recvData) + localRank * dataLen, sendData, dataLen);
        }
    }

    if(wait) {
        // 等待所有任务完成
503
504
        while(!pthread_pool->allTasksCompleted()) {
            usleep(1000); // 每1毫秒检查一次任务完成状态
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
        }
    }

    return scclSuccess;
}

/**
 * @brief 使用IPC套接字进行Allgather同步操作
 *
 * 该函数实现了基于IPC套接字的Allgather同步操作,将各进程的数据收集到所有进程的接收缓冲区中。
 *
 * @param sendData 发送数据缓冲区指针
 * @param recvData 接收数据缓冲区指针
 * @param dataLen 每个进程发送/接收的数据长度
 * @param wait 是否等待所有通信任务完成
 * @return scclResult_t 返回操作结果,成功返回scclSuccess,失败返回错误码
 *
 * @note 1. 函数会先将本地数据复制到接收缓冲区对应位置
 *       2. 使用线程池并行处理与其他进程的通信任务
 *       3. 当wait为true时会阻塞等待所有通信完成
 */
scclResult_t scclIpcSocket::scclIpcSocketAllgatherSync(const void* sendData, void* recvData, size_t dataLen, bool wait) {
527
    if(pthread_pool == nullptr || nlocalRanks <= 0) {
528
529
530
531
532
533
534
535
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }

    // 将当前进程的数据复制到接收缓冲区的对应位置
    memcpy(static_cast<char*>(recvData) + localRank * dataLen, sendData, dataLen);

    // 采用线程池发送和接收数据
536
    for(int i = 0; i < nlocalRanks; ++i) {
537
538
        if(i != localRank) {
            auto sendTask = [this, sendData, dataLen, i]() { scclIpcSocketSendData(sendData, dataLen, i); };
539
            pthread_pool->enqueue(sendTask);
540
541
542
543
544

            auto recvTask = [this, recvData, dataLen, i]() {
                size_t receivedLen;
                scclIpcSocketRecvData(reinterpret_cast<char*>(recvData) + i * dataLen, dataLen, &receivedLen);
            };
545
            pthread_pool->enqueue(recvTask);
546
547
548
549
550
        }
    }

    if(wait) {
        // 等待所有任务完成
551
552
        while(!pthread_pool->allTasksCompleted()) {
            usleep(1000); // 每1毫秒检查一次任务完成状态
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
        }
    }

    return scclSuccess;
}

/**
 * @brief 通过IPC Socket进行广播操作
 *
 * 该函数实现了基于IPC Socket的广播通信机制。根进程(root)将数据发送给所有其他进程,
 * 非根进程从根进程接收数据。可以选择是否等待所有通信操作完成。
 *
 * @param sendData 发送数据缓冲区指针(根进程使用)
 * @param recvData 接收数据缓冲区指针(非根进程使用)
 * @param dataLen 数据长度(字节)
 * @param root 根进程的rank值
 * @param wait 是否等待所有通信操作完成
 *
 * @return scclResult_t 返回操作结果状态码
 *     - scclSuccess: 操作成功
 *     - scclInternalError: IPC Socket未初始化或本地rank数无效
 *     - scclInvalidArgument: 根进程rank值无效
 */
576
577
578
579
580
scclResult_t scclIpcSocket::scclIpcSocketBroadcast(void* data, size_t dataLen, int root, bool wait) {

    pthread_pool->allTasksCompleted();

    if(pthread_pool == nullptr || nlocalRanks <= 0) {
581
582
583
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }
584
    if(root < 0 || root >= nlocalRanks) {
585
586
587
588
        WARN("scclIpcSocketBroadcast: Invalid root rank %d", root);
        return scclInvalidArgument;
    }

589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
    // if(localRank == root) {
    //     // 根进程:发送数据给所有其他进程
    //     for(int i = 0; i < nlocalRanks; ++i) {
    //         if(i != root) {
    //             // 使用 std::bind 绑定 scclIpcSocketSendDataWithAck 方法和参数
    //             auto sendTask = std::bind(&scclIpcSocket::scclIpcSocketSendDataWithAck, this, data, dataLen, i);
    //             // 将绑定后的函数对象添加到线程池的任务队列中
    //             pthread_pool->enqueue(sendTask);

    //             printf("send root: %d, i=%d\n", root, i);
    //         }
    //     }
    // } else {
    //     size_t receivedLen;
    //     scclResult_t result = scclIpcSocketRecvDataAndSendAck(data, dataLen, &receivedLen, root);
    //     if(result != scclSuccess) {
    //         return result;
    //     }
    //     printf("recv from root: localRank=%d\n", localRank);
    // }
609

610
611
612
613
    if(wait) {
        // 等待所有任务完成
        while(!pthread_pool->allTasksCompleted()) {
            usleep(1000); // 每1毫秒检查一次任务完成状态
614
615
616
        }
    }

617
618
619
620
621
622
623
624
625
    return scclSuccess;
}

/////////////////////////////////////////////////////////////////////////////////////
scclResult_t scclIpcSocket::getScclIpcSocknameStr(int rank, uint64_t hash, char* out_str, int* out_len) {
    int len = snprintf(out_str, SCCL_IPC_SOCKNAME_LEN, "/tmp/sccl-socket-%d-%lx", rank, hash);
    if(len > (sizeof(my_cliaddr.sun_path) - 1)) {
        WARN("UDS: Cannot bind provided name to socket. Name too large");
        return scclInternalError;
626
627
    }

628
    *out_len = len;
629
630
631
632
633
634
635
    return scclSuccess;
}

} // namespace ipc_socket
} // namespace net
} // namespace hardware
} // namespace sccl