ipc_socket.cpp 26.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <pthread.h>
#include <stdlib.h>
#include <poll.h>
#include <limits.h>
#include <fcntl.h>
#include <thread> // 为了使用 std::this_thread::sleep_for
#include "ipc_socket.h"

namespace sccl {
namespace hardware {
namespace net {

namespace ipc_socket {

//////////////////////////////////////// scclIpcSocket调用的函数 ////////////////////////////////////////
16
17
scclIpcSocket::scclIpcSocket(int localRank, int nlocalRanks, uint64_t hash, volatile uint32_t* abortFlag)
    : localRank(localRank), nlocalRanks(nlocalRanks), ipc_hash(hash) {
18
    scclResult_t res;
19
20
21
22
    // 初始化handle
    handle                = new struct scclIpcSocketHandle();
    handle->fd            = -1;
    handle->socketName[0] = '\0';
23

24
25
26
27
28
    // 设置线程池
    if(nlocalRanks > 0) {
        pthread_pool = new ThreadPool(nlocalRanks * 2); // 其中一半用于发送一半,用于接收
    } else {
        goto failure;
29
30
31
32
33
34
35
36
37
38
39
    }

    SCCLCHECKGOTO(scclIpcSocketInit(abortFlag), res, failure);
    return;

failure:
    WARN("scclIpcSocket init failed");
    return;
}

scclIpcSocket::~scclIpcSocket() {
40
41
42
43
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
        usleep(1000); // 每1毫秒检查一次任务完成状态
    }
44

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    // 释放pthpool
    if(pthread_pool) {
        delete(pthread_pool);
    }

    // 释放handle
    if(handle->socketName[0] != '\0') {
        unlink(handle->socketName);
    }
    if(handle->fd >= 0) {
        close(handle->fd);
    }
    delete(handle);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

scclResult_t scclIpcSocket::scclIpcSocketInit(volatile uint32_t* abortFlag) {
    // 中间变量
    int fd = -1;
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];

    // 创建Unix域套接字
    // af是本机IP地址类型,一般有PF_INET或者AF_INET(IPv4互联网协议族),还有PF_INET6(IPv6互联网协议族)等,但是一般用IPv4。
    // type有两种SOCK_STREAM 和SOCK_DGRAM分别对应tcp和udp协议,区别是用不用建立连接。
    if((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) {
        WARN("UDS: Socket creation error : %d", errno);
        return scclSystemError;
    }

    // 将cliaddr结构体清零,确保没有残留数据
    bzero(&my_cliaddr, sizeof(my_cliaddr));
    my_cliaddr.sun_family = AF_UNIX;

    // 为套接字创建唯一名称
80
81
    int len;
    SCCLCHECK(getScclIpcSocknameStr(localRank, ipc_hash, temp_addr, &len));
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Creating socket %s", temp_addr);

    // 设置套接字路径
    strncpy(my_cliaddr.sun_path, temp_addr, len);
    my_cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧

    // 绑定套接字
    if(bind(fd, (struct sockaddr*)&my_cliaddr, sizeof(my_cliaddr)) < 0) {
        WARN("UDS: Binding to socket %s failed : %d", temp_addr, errno);
        close(fd);
        return scclSystemError;
    }

    // 设置handle的成员变量
    handle->fd = fd;
    strcpy(handle->socketName, temp_addr);

    // 设置中止标志
    handle->abortFlag = abortFlag;
    // 将套接字标记为非阻塞
    if(handle->abortFlag) {
        int flags;
        EQCHECK(flags = fcntl(fd, F_GETFL), -1);
        SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
    }

    return scclSuccess;
}

/**
 * 设置中止标志并更新socket的非阻塞模式
 *
 * @param flag 指向中止标志的指针。如果非空,将socket设为非阻塞模式;
 *             如果为空,则恢复为阻塞模式。
 * @note 该函数仅在handle有效时执行操作
 */
scclResult_t scclIpcSocket::setAbortFlag(volatile uint32_t* flag) {
    if(handle) {
        handle->abortFlag = flag;
        if(flag) {
            int flags;
            EQCHECK(flags = fcntl(handle->fd, F_GETFL), -1);
            SYSCHECK(fcntl(handle->fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
        } else {
            int flags;
            EQCHECK(flags = fcntl(handle->fd, F_GETFL), -1);
            SYSCHECK(fcntl(handle->fd, F_SETFL, flags & ~O_NONBLOCK), "fcntl");
        }
    }
    return scclSuccess;
}

// 获取 abortFlag 的函数
volatile uint32_t* scclIpcSocket::getAbortFlag() const { return handle ? handle->abortFlag : nullptr; }

/**
 * 设置IPC套接字的超时时间
 *
 * @param timeout_ms 超时时间(毫秒)
 * @return 成功返回scclSuccess
 */
scclResult_t scclIpcSocket::setTimeout(int timeout_ms) {
    timeoutMs = timeout_ms;
    return scclSuccess;
}

ThreadPool* scclIpcSocket::getPthreadPool() { return pthread_pool; }

//////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * @brief 通过Unix域套接字发送文件描述符
 *
 * @param sendFd 要发送的文件描述符
 * @param dst_rank 目标rank号
 * @return scclResult_t 返回操作结果:
 *         - scclSuccess: 发送成功
 *         - scclInternalError: 内部错误(如地址过长或中止标志被设置)
 *         - scclSystemError: 系统调用错误
 *
 * @note 使用Linux抽象套接字技巧(将sun_path[0]置为'\0')
 *       通过SCM_RIGHTS机制发送文件描述符
 *       函数会循环尝试发送直到成功或遇到错误
 */
scclResult_t scclIpcSocket::scclIpcSocketSendFd(const int sendFd, int dst_rank) {
    // 创建一个临时地址字符串
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];
    // 格式化地址字符串
169
170
171
    int len;
    SCCLCHECK(getScclIpcSocknameStr(dst_rank, ipc_hash, temp_addr, &len));

172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
    // 记录发送文件描述符的信息
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Sending fd %d to UDS socket %s/fd:%d", sendFd, temp_addr, handle->fd);

    // 初始化消息头结构体和iovec结构体
    struct msghdr msg;
    struct iovec iov[1];

    // 联合体用于保证控制数组的对齐要求
    union {
        struct cmsghdr cm;
        char control[CMSG_SPACE(sizeof(int))];
    } control_un;

    struct cmsghdr* cmptr;
    struct sockaddr_un cliaddr;

    // 构造客户端地址以发送共享句柄
    bzero(&cliaddr, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;
    strncpy(cliaddr.sun_path, temp_addr, len);
    cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧

    // 设置消息头的控制信息部分
    msg.msg_control    = control_un.control;
    msg.msg_controllen = sizeof(control_un.control);

    cmptr             = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_len   = CMSG_LEN(sizeof(int));
    cmptr->cmsg_level = SOL_SOCKET;
    cmptr->cmsg_type  = SCM_RIGHTS;

    // 将要发送的文件描述符复制到控制信息中
    memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd));

    // 设置消息头的地址信息部分
    msg.msg_name    = (void*)&cliaddr;
    msg.msg_namelen = sizeof(struct sockaddr_un);

    // 设置iovec结构体,用于指定要发送的数据
    iov[0].iov_base = (void*)"";
    iov[0].iov_len  = 1;

    // 将iovec结构体关联到消息头
    msg.msg_iov    = iov;
    msg.msg_iovlen = 1;

    // 初始化消息标志
    msg.msg_flags = 0;

    ssize_t sendResult;
    // 循环发送消息,直到成功发送数据
    while((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) {
        // 如果发送失败且错误不是EAGAIN, EWOULDBLOCK或EINTR,则记录警告并返回错误
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
            WARN("UDS: Sending data over socket %s failed : %d", temp_addr, errno);
            return scclSystemError;
        }
        // 如果设置了中止标志,则返回内部错误
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;
    }

    return scclSuccess;
}

237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
/**
 * @brief 通过IPC socket接收文件描述符
 *
 * 该函数使用recvmsg系统调用从socket接收文件描述符。函数会循环尝试接收,
 * 直到成功或发生错误。接收到的文件描述符会通过参数recvFd返回。
 *
 * @param recvFd 用于存储接收到的文件描述符的指针
 * @return scclResult_t 返回操作结果:
 *         - scclSuccess: 成功接收文件描述符
 *         - scclSystemError: 系统调用失败
 *         - scclInternalError: 操作被中止
 *
 * @note 函数会处理EAGAIN、EWOULDBLOCK和EINTR错误,其他错误会导致返回失败。
 *       接收到的控制消息必须符合SOL_SOCKET级别和SCM_RIGHTS类型。
 */
scclResult_t scclIpcSocket::scclIpcSocketRecvFd(int* recvFd) {
    // 初始化消息头结构体和iovec结构体
    struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
    struct iovec iov[1];

    // 联合体用于保证控制数组的对齐要求
    union {
        struct cmsghdr cm;
        char control[CMSG_SPACE(sizeof(int))];
    } control_un;

    struct cmsghdr* cmptr;
    char dummy_buffer[1];
    int ret;

    // 设置消息头的控制信息部分
    msg.msg_control    = control_un.control;
    msg.msg_controllen = sizeof(control_un.control);

    // 设置iovec结构体,用于指定要接收的数据
    iov[0].iov_base = (void*)dummy_buffer;
    iov[0].iov_len  = sizeof(dummy_buffer);

    // 将iovec结构体关联到消息头
    msg.msg_iov    = iov;
    msg.msg_iovlen = 1;

    // 循环接收消息,直到成功接收到数据
    while((ret = recvmsg(handle->fd, &msg, 0)) <= 0) {
        // 如果接收失败且错误不是EAGAIN, EWOULDBLOCK或EINTR,则记录警告并返回错误
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
            WARN("UDS: Receiving data over socket failed : %d", errno);
            return scclSystemError;
        }
        // 如果设置了中止标志,则返回内部错误
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;
    }

    // 检查接收到的控制信息
    if(((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
        // 如果控制信息的级别或类型不正确,则记录警告并返回错误
        if((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
            WARN("UDS: Receiving data over socket failed");
            return scclSystemError;
        }

        // 将接收到的文件描述符复制到recvFd
        memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));
    } else {
        // 如果没有接收到控制信息,则记录警告并返回错误
        WARN("UDS: Receiving data over socket %s failed", handle->socketName);
        return scclSystemError;
    }

    // 记录成功接收到文件描述符的信息
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName);

    // 返回成功
    return scclSuccess;
}

//////////////////////////////////////////////////////////////////////////////////////////////////////
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
/**
 * @brief 通过IPC socket接收文件描述符
 *
 * 该函数使用recvmsg系统调用从socket接收文件描述符。函数会循环尝试接收,
 * 直到成功或发生错误。接收到的文件描述符会通过参数recvFd返回。
 *
 * @param recvFd 用于存储接收到的文件描述符的指针
 * @return scclResult_t 返回操作结果:
 *         - scclSuccess: 成功接收文件描述符
 *         - scclSystemError: 系统调用失败
 *         - scclInternalError: 操作被中止
 *
 * @note 函数会处理EAGAIN、EWOULDBLOCK和EINTR错误,其他错误会导致返回失败。
 *       接收到的控制消息必须符合SOL_SOCKET级别和SCM_RIGHTS类型。
 */
scclResult_t scclIpcSocket::scclIpcSocketSendData(const void* data, size_t dataLen, int dst_rank) {
    // 构造目标地址字符串
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];
333
334
    int len;
    SCCLCHECK(getScclIpcSocknameStr(dst_rank, ipc_hash, temp_addr, &len));
335
336
337
338
339
340
341
342
343

    // 设置消息结构体
    struct msghdr msg;
    struct iovec iov[1];
    struct sockaddr_un cliaddr;
    bzero(&cliaddr, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;
    strncpy(cliaddr.sun_path, temp_addr, len);
    cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧
344
345
346
347
348
    msg.msg_name        = (void*)&cliaddr;
    msg.msg_namelen     = sizeof(cliaddr);
    msg.msg_control     = NULL;
    msg.msg_controllen  = 0;
    msg.msg_flags       = 0;
349

350
351
352
353
    iov[0].iov_base = (void*)data;
    iov[0].iov_len  = dataLen;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 1;
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389

    // 使用 poll 等待 socket 可写
    struct pollfd pfd;
    pfd.fd     = handle->fd;
    pfd.events = POLLOUT;

    int pollResult = poll(&pfd, 1, timeoutMs);
    if(pollResult <= 0) {
        if(pollResult == 0) {
            WARN("UDS: Timeout occurred while waiting to send data to socket %s", temp_addr);
        } else {
            WARN("UDS: Error occurred while polling socket %s for writability : %d", temp_addr, errno);
        }
        return scclSystemError;
    }

    ssize_t sendResult;
    while((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) {
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
            WARN("UDS: Error occurred while sending data through socket %s : %d", temp_addr, errno);
            return scclSystemError;
        }
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;

        pollResult = poll(&pfd, 1, timeoutMs);
        if(pollResult <= 0) {
            if(pollResult == 0) {
                WARN("UDS: Timeout occurred while waiting to send data to socket %s", temp_addr);
            } else {
                WARN("UDS: Error occurred while polling socket %s for writability : %d", temp_addr, errno);
            }
            return scclSystemError;
        }
    }

390
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully sent %zd bytes of data through UDS socket %s", sendResult, temp_addr);
391
392
393
394
    return scclSuccess;
}

/**
395
 * @brief 通过IPC套接字发送数据到指定目标rank
396
 *
397
398
399
400
401
402
403
 * @param data 要发送的数据指针
 * @param dataLen 要发送的数据长度
 * @param dst_rank 目标rank号
 * @return scclResult_t 返回操作结果状态码:
 *         - scclSuccess: 发送成功
 *         - scclInternalError: 内部错误(如套接字名称过长或中止标志被设置)
 *         - scclSystemError: 系统调用错误(如poll超时或sendmsg失败)
404
 *
405
406
 * @note 使用Linux抽象套接字技术,通过poll机制确保套接字可写后再发送数据
 *       支持EAGAIN/EWOULDBLOCK/EINTR错误重试机制
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
 */
scclResult_t scclIpcSocket::scclIpcSocketRecvData(void* buffer, size_t bufferLen, size_t* receivedLen) {
    // 设置消息结构体
    struct msghdr msg = {0};
    struct iovec iov[1];
    iov[0].iov_base = buffer;
    iov[0].iov_len  = bufferLen;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 1;

    // 使用 poll 等待 socket 可读
    struct pollfd pfd;
    pfd.fd     = handle->fd;
    pfd.events = POLLIN;

    int pollResult = poll(&pfd, 1, timeoutMs);
    if(pollResult <= 0) {
        if(pollResult == 0) {
            WARN("UDS: Timeout occurred while waiting to receive data from socket %s", handle->socketName);
        } else {
            WARN("UDS: Error occurred while polling socket %s for readability : %d", handle->socketName, errno);
        }
        return scclSystemError;
    }

    int ret;
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
    while(true) {
        ret = recvmsg(handle->fd, &msg, 0);
        if(ret > 0) {
            INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully received %d bytes of data from socket %s", ret, handle->socketName);
            *receivedLen = ret;
            return scclSuccess;
        } else if(ret == 0) {
            INFO(SCCL_LOG_BOOTSTRAP, "UDS: Connection closed by peer on socket %s", handle->socketName);
            *receivedLen = 0;
            return scclSuccess;
        } else {
            if(errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
                pollResult = poll(&pfd, 1, timeoutMs);
                if(pollResult <= 0) {
                    if(pollResult == 0) {
                        WARN("UDS: Timeout occurred while waiting to receive data from socket %s", handle->socketName);
                    } else {
                        WARN("UDS: Error occurred while polling socket %s for readability : %d", handle->socketName, errno);
                    }
                    return scclSystemError;
                }
454
            } else {
455
456
                WARN("UDS: Error occurred while receiving data through socket %s : %d", handle->socketName, errno);
                return scclSystemError;
457
458
459
460
461
            }
        }
    }
}

462
463
464
465
466
467
468
469
470
471
472
473
/**
 * 通过IPC Socket发送数据并等待确认
 *
 * @param data 要发送的数据指针
 * @param dataLen 要发送的数据长度
 * @param dst_rank 目标rank号
 * @return scclSuccess 发送成功,其他错误码表示失败
 *
 * 该函数会将数据分块发送(CHUNK_SIZE大小),每发送一个数据块后
 * 会等待接收方返回ACK确认。如果收到非预期的ACK或发送/接收失败,
 * 会立即返回错误。所有数据成功发送并收到正确ACK后返回成功。
 */
474
scclResult_t scclIpcSocket::scclIpcSocketSendDataWithAck(const void* data, size_t dataLen, int dst_rank) {
475
476
    const char* dataPtr = static_cast<const char*>(data);
    size_t bytesSent    = 0;
477

478
479
    while(bytesSent < dataLen) {
        size_t bytesToSend = std::min(CHUNK_SIZE, dataLen - bytesSent);
480

481
482
483
484
485
        // 发送数据块
        scclResult_t sendResult = scclIpcSocketSendData(dataPtr + bytesSent, bytesToSend, dst_rank);
        if(sendResult != scclSuccess) {
            return sendResult;
        }
486

487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
        // 等待接收方的ACK
        char ack[ACK_SIZE];
        size_t receivedLen;
        scclResult_t recvResult = scclIpcSocketRecvData(ack, sizeof(ack), &receivedLen);
        if(recvResult != scclSuccess) {
            return recvResult;
        }

        // 检查是否是预期的ack
        char target_ack[ACK_SIZE];
        sprintf(target_ack, "ACK-%d", localRank);
        if(strcmp(ack, target_ack) != 0) {
            WARN("UDS: Received unexpected ACK: %s", ack);
            return scclSystemError;
        }

        bytesSent += bytesToSend;
504
    }
505

506
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully sent %zu bytes of data with ACK through UDS socket", dataLen);
507
508
    return scclSuccess;
}
509

510
511
512
513
514
515
516
517
518
519
520
521
/**
 * 通过IPC Socket接收数据并发送ACK确认
 *
 * @param buffer 接收数据缓冲区指针
 * @param bufferLen 缓冲区总长度
 * @param receivedLen 实际接收到的数据长度(输出参数)
 * @param src_rank 发送方rank号
 * @return scclSuccess表示成功,其他错误码表示失败
 *
 * @note 采用分块接收机制,每接收一个数据块都会发送ACK确认
 *       接收完成后会记录日志信息
 */
522
scclResult_t scclIpcSocket::scclIpcSocketRecvDataAndSendAck(void* buffer, size_t bufferLen, size_t* receivedLen, int src_rank) {
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
    char* bufferPtr      = static_cast<char*>(buffer);
    size_t bytesReceived = 0;

    while(bytesReceived < bufferLen) {
        size_t bytesToReceive = std::min(CHUNK_SIZE, bufferLen - bytesReceived);

        // 接收数据块
        scclResult_t recvResult = scclIpcSocketRecvData(bufferPtr + bytesReceived, bytesToReceive, receivedLen);
        if(recvResult != scclSuccess) {
            return recvResult;
        }

        // 发送ACK给发送方
        char ack[ACK_SIZE];
        sprintf(ack, "ACK-%d", src_rank);
        scclResult_t sendResult = scclIpcSocketSendData(ack, strlen(ack), src_rank);
        if(sendResult != scclSuccess) {
            return sendResult;
        }

        bytesReceived += *receivedLen;
544
    }
545

546
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully received %zu bytes of data and sent ACK through UDS socket", bufferLen);
547
    return scclSuccess;
548
549
}

550
/////////////////////////////////////////////////////////////////////////////////////////////////////
551
552
553
554
555
556
557
558
559
560
561
562
/**
 * @brief 使用IPC套接字实现Allgather操作
 *
 * 该函数通过线程池并行发送和接收数据,实现多节点间的Allgather集合通信。
 *
 * @param sendData 发送数据缓冲区指针
 * @param recvData 接收数据缓冲区指针
 * @param dataLen 每个节点的数据长度(字节)
 * @return scclResult_t 返回操作结果(scclSuccess表示成功)
 *
 * @note 1. 会跳过本地rank的数据传输
 *       2. 数据包格式: [发送rank(int)][数据]
563
 *       3. 接收缓冲区需要预先分配足够空间(大小=nlocalRanks*dataLen)
564
 */
565
scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* recvData, size_t dataLen) {
566
    if(pthread_pool == nullptr || nlocalRanks <= 0) {
567
568
569
570
571
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }

    // 采用线程池发送和接收数据
572
    for(int i = 0; i < nlocalRanks; ++i) {
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
        if(i != localRank) {
            auto sendTask = [this, sendData, dataLen, i]() {
                // 计算 DataPackage 的总大小
                size_t packageSize = sizeof(int) + dataLen;
                char* buffer       = new char[packageSize];

                // 将 rank 信息和数据一起拷贝到 buffer 中
                int* rankPtr = reinterpret_cast<int*>(buffer);
                *rankPtr     = localRank;

                char* dataPtr = buffer + sizeof(int);
                memcpy(dataPtr, sendData, dataLen);

                // 一次性发送 rank 信息和数据
                scclIpcSocketSendData(buffer, packageSize, i);

                delete[] buffer;
            };
591
            pthread_pool->enqueue(sendTask);
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610

            auto recvTask = [this, recvData, dataLen, i]() {
                // 准备接收缓冲区
                size_t packageSize = sizeof(int) + dataLen;
                char* buffer       = new char[packageSize];
                size_t receivedLen;

                // 一次性接收 rank 信息和数据
                scclIpcSocketRecvData(buffer, packageSize, &receivedLen);

                // 从 buffer 中提取 rank 信息和数据
                int* rankPtr   = reinterpret_cast<int*>(buffer);
                int senderRank = *rankPtr;

                char* dataPtr = buffer + sizeof(int);
                memcpy(static_cast<char*>(recvData) + senderRank * dataLen, dataPtr, dataLen);

                delete[] buffer;
            };
611
            pthread_pool->enqueue(recvTask);
612
613
614
615
616
617
        } else {
            // 自己的数据直接放置到正确位置
            memcpy(static_cast<char*>(recvData) + localRank * dataLen, sendData, dataLen);
        }
    }

618
619
620
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
        usleep(1000); // 每1毫秒检查一次任务完成状态
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
    }

    return scclSuccess;
}

/**
 * @brief 使用IPC套接字进行Allgather同步操作
 *
 * 该函数实现了基于IPC套接字的Allgather同步操作,将各进程的数据收集到所有进程的接收缓冲区中。
 *
 * @param sendData 发送数据缓冲区指针
 * @param recvData 接收数据缓冲区指针
 * @param dataLen 每个进程发送/接收的数据长度
 * @return scclResult_t 返回操作结果,成功返回scclSuccess,失败返回错误码
 *
 * @note 1. 函数会先将本地数据复制到接收缓冲区对应位置
 *       2. 使用线程池并行处理与其他进程的通信任务
 *       3. 当wait为true时会阻塞等待所有通信完成
 */
640
scclResult_t scclIpcSocket::scclIpcSocketAllgatherSync(const void* sendData, void* recvData, size_t dataLen) {
641
    if(pthread_pool == nullptr || nlocalRanks <= 0) {
642
643
644
645
646
647
648
649
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }

    // 将当前进程的数据复制到接收缓冲区的对应位置
    memcpy(static_cast<char*>(recvData) + localRank * dataLen, sendData, dataLen);

    // 采用线程池发送和接收数据
650
    for(int i = 0; i < nlocalRanks; ++i) {
651
652
        if(i != localRank) {
            auto sendTask = [this, sendData, dataLen, i]() { scclIpcSocketSendData(sendData, dataLen, i); };
653
            pthread_pool->enqueue(sendTask);
654
655
656
657
658

            auto recvTask = [this, recvData, dataLen, i]() {
                size_t receivedLen;
                scclIpcSocketRecvData(reinterpret_cast<char*>(recvData) + i * dataLen, dataLen, &receivedLen);
            };
659
            pthread_pool->enqueue(recvTask);
660
661
662
        }
    }

663
664
665
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
        usleep(1000); // 每1毫秒检查一次任务完成状态
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
    }

    return scclSuccess;
}

/**
 * @brief 通过IPC Socket进行广播操作
 *
 * 该函数实现了基于IPC Socket的广播通信机制。根进程(root)将数据发送给所有其他进程,
 * 非根进程从根进程接收数据。可以选择是否等待所有通信操作完成。
 *
 * @param sendData 发送数据缓冲区指针(根进程使用)
 * @param recvData 接收数据缓冲区指针(非根进程使用)
 * @param dataLen 数据长度(字节)
 * @param root 根进程的rank值
 *
 * @return scclResult_t 返回操作结果状态码
 *     - scclSuccess: 操作成功
 *     - scclInternalError: IPC Socket未初始化或本地rank数无效
 *     - scclInvalidArgument: 根进程rank值无效
 */
687
scclResult_t scclIpcSocket::scclIpcSocketBroadcast(void* data, size_t dataLen, int root) {
688
    if(pthread_pool == nullptr || nlocalRanks <= 0) {
689
690
691
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }
692
    if(root < 0 || root >= nlocalRanks) {
693
694
695
696
        WARN("scclIpcSocketBroadcast: Invalid root rank %d", root);
        return scclInvalidArgument;
    }

697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
    if(localRank == root) {
        // 根进程:发送数据给所有其他进程
        for(int i = 0; i < nlocalRanks; ++i) {
            if(i != root) {
                // 使用 std::bind 绑定 scclIpcSocketSendDataWithAck 方法和参数
                auto sendTask = std::bind(&scclIpcSocket::scclIpcSocketSendDataWithAck, this, data, dataLen, i);
                // 将绑定后的函数对象添加到线程池的任务队列中
                pthread_pool->enqueue(sendTask);
            }
        }
    } else {
        size_t receivedLen;
        scclResult_t result = scclIpcSocketRecvDataAndSendAck(data, dataLen, &receivedLen, root);
        if(result != scclSuccess) {
            return result;
712
713
714
        }
    }

715
716
717
718
719
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
        usleep(1000); // 每1毫秒检查一次任务完成状态
    }

720
721
722
723
724
725
726
727
728
    return scclSuccess;
}

/////////////////////////////////////////////////////////////////////////////////////
scclResult_t scclIpcSocket::getScclIpcSocknameStr(int rank, uint64_t hash, char* out_str, int* out_len) {
    int len = snprintf(out_str, SCCL_IPC_SOCKNAME_LEN, "/tmp/sccl-socket-%d-%lx", rank, hash);
    if(len > (sizeof(my_cliaddr.sun_path) - 1)) {
        WARN("UDS: Cannot bind provided name to socket. Name too large");
        return scclInternalError;
729
730
    }

731
    *out_len = len;
732
733
734
735
736
737
738
    return scclSuccess;
}

} // namespace ipc_socket
} // namespace net
} // namespace hardware
} // namespace sccl