ipc_socket.cpp 36.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include <pthread.h>
#include <stdlib.h>
#include <poll.h>
#include <limits.h>
#include <fcntl.h>
#include <thread> // 为了使用 std::this_thread::sleep_for
#include "ipc_socket.h"

namespace sccl {
namespace hardware {
namespace net {

namespace ipc_socket {

//////////////////////////////////////// scclIpcSocket调用的函数 ////////////////////////////////////////
16
17
scclIpcSocket::scclIpcSocket(int localRank, int nlocalRanks, uint64_t hash, volatile uint32_t* abortFlag)
    : localRank(localRank), nlocalRanks(nlocalRanks), ipc_hash(hash) {
18
    scclResult_t res;
19
20
21
22
    // 初始化handle
    handle                = new struct scclIpcSocketHandle();
    handle->fd            = -1;
    handle->socketName[0] = '\0';
23

24
25
26
27
28
    // 设置线程池
    if(nlocalRanks > 0) {
        pthread_pool = new ThreadPool(nlocalRanks * 2); // 其中一半用于发送一半,用于接收
    } else {
        goto failure;
29
30
31
32
33
34
35
36
37
38
39
    }

    SCCLCHECKGOTO(scclIpcSocketInit(abortFlag), res, failure);
    return;

failure:
    WARN("scclIpcSocket init failed");
    return;
}

scclIpcSocket::~scclIpcSocket() {
40
    printf("scclIpcSocket 析构函数 localRank=%d\n", localRank);
41
42
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
43
        usleep(100); // 每1毫秒检查一次任务完成状态
44
    }
45

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
    // 释放pthpool
    if(pthread_pool) {
        delete(pthread_pool);
    }

    // 释放handle
    if(handle->socketName[0] != '\0') {
        unlink(handle->socketName);
    }
    if(handle->fd >= 0) {
        close(handle->fd);
    }
    delete(handle);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

scclResult_t scclIpcSocket::scclIpcSocketInit(volatile uint32_t* abortFlag) {
    // 中间变量
    int fd = -1;
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];

    // 创建Unix域套接字
    // af是本机IP地址类型,一般有PF_INET或者AF_INET(IPv4互联网协议族),还有PF_INET6(IPv6互联网协议族)等,但是一般用IPv4。
    // type有两种SOCK_STREAM 和SOCK_DGRAM分别对应tcp和udp协议,区别是用不用建立连接。
    if((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) {
72
        WARN("UDS: Socket creation error : %d (%s)", errno, strerror(errno));
73
74
75
76
77
78
79
80
        return scclSystemError;
    }

    // 将cliaddr结构体清零,确保没有残留数据
    bzero(&my_cliaddr, sizeof(my_cliaddr));
    my_cliaddr.sun_family = AF_UNIX;

    // 为套接字创建唯一名称
81
82
    int len;
    SCCLCHECK(getScclIpcSocknameStr(localRank, ipc_hash, temp_addr, &len));
83
84
85
86
87
88
89
90
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Creating socket %s", temp_addr);

    // 设置套接字路径
    strncpy(my_cliaddr.sun_path, temp_addr, len);
    my_cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧

    // 绑定套接字
    if(bind(fd, (struct sockaddr*)&my_cliaddr, sizeof(my_cliaddr)) < 0) {
91
        WARN("UDS: Binding to socket %s failed : %d (%s)", temp_addr, errno, strerror(errno));
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
        close(fd);
        return scclSystemError;
    }

    // 设置handle的成员变量
    handle->fd = fd;
    strcpy(handle->socketName, temp_addr);

    // 设置中止标志
    handle->abortFlag = abortFlag;
    // 将套接字标记为非阻塞
    if(handle->abortFlag) {
        int flags;
        EQCHECK(flags = fcntl(fd, F_GETFL), -1);
        SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
    }

    return scclSuccess;
}

/**
 * 设置中止标志并更新socket的非阻塞模式
 *
 * @param flag 指向中止标志的指针。如果非空,将socket设为非阻塞模式;
 *             如果为空,则恢复为阻塞模式。
 * @note 该函数仅在handle有效时执行操作
 */
scclResult_t scclIpcSocket::setAbortFlag(volatile uint32_t* flag) {
    if(handle) {
        handle->abortFlag = flag;
        if(flag) {
            int flags;
            EQCHECK(flags = fcntl(handle->fd, F_GETFL), -1);
            SYSCHECK(fcntl(handle->fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
        } else {
            int flags;
            EQCHECK(flags = fcntl(handle->fd, F_GETFL), -1);
            SYSCHECK(fcntl(handle->fd, F_SETFL, flags & ~O_NONBLOCK), "fcntl");
        }
    }
    return scclSuccess;
}

// 获取 abortFlag 的函数
volatile uint32_t* scclIpcSocket::getAbortFlag() const { return handle ? handle->abortFlag : nullptr; }

/**
 * 设置IPC套接字的超时时间
 *
 * @param timeout_ms 超时时间(毫秒)
 * @return 成功返回scclSuccess
 */
scclResult_t scclIpcSocket::setTimeout(int timeout_ms) {
    timeoutMs = timeout_ms;
    return scclSuccess;
}

ThreadPool* scclIpcSocket::getPthreadPool() { return pthread_pool; }

//////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * @brief 通过Unix域套接字发送文件描述符
 *
 * @param sendFd 要发送的文件描述符
 * @param dst_rank 目标rank号
 * @return scclResult_t 返回操作结果:
 *         - scclSuccess: 发送成功
 *         - scclInternalError: 内部错误(如地址过长或中止标志被设置)
 *         - scclSystemError: 系统调用错误
 *
 * @note 使用Linux抽象套接字技巧(将sun_path[0]置为'\0')
 *       通过SCM_RIGHTS机制发送文件描述符
 *       函数会循环尝试发送直到成功或遇到错误
 */
scclResult_t scclIpcSocket::scclIpcSocketSendFd(const int sendFd, int dst_rank) {
    // 创建一个临时地址字符串
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];
    // 格式化地址字符串
170
171
172
    int len;
    SCCLCHECK(getScclIpcSocknameStr(dst_rank, ipc_hash, temp_addr, &len));

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
    // 记录发送文件描述符的信息
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Sending fd %d to UDS socket %s/fd:%d", sendFd, temp_addr, handle->fd);

    // 初始化消息头结构体和iovec结构体
    struct msghdr msg;
    struct iovec iov[1];

    // 联合体用于保证控制数组的对齐要求
    union {
        struct cmsghdr cm;
        char control[CMSG_SPACE(sizeof(int))];
    } control_un;

    struct cmsghdr* cmptr;
    struct sockaddr_un cliaddr;

    // 构造客户端地址以发送共享句柄
    bzero(&cliaddr, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;
    strncpy(cliaddr.sun_path, temp_addr, len);
    cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧

    // 设置消息头的控制信息部分
    msg.msg_control    = control_un.control;
    msg.msg_controllen = sizeof(control_un.control);

    cmptr             = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_len   = CMSG_LEN(sizeof(int));
    cmptr->cmsg_level = SOL_SOCKET;
    cmptr->cmsg_type  = SCM_RIGHTS;

    // 将要发送的文件描述符复制到控制信息中
    memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd));

    // 设置消息头的地址信息部分
    msg.msg_name    = (void*)&cliaddr;
    msg.msg_namelen = sizeof(struct sockaddr_un);

    // 设置iovec结构体,用于指定要发送的数据
    iov[0].iov_base = (void*)"";
    iov[0].iov_len  = 1;

    // 将iovec结构体关联到消息头
    msg.msg_iov    = iov;
    msg.msg_iovlen = 1;

    // 初始化消息标志
    msg.msg_flags = 0;

    ssize_t sendResult;
    // 循环发送消息,直到成功发送数据
    while((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) {
        // 如果发送失败且错误不是EAGAIN, EWOULDBLOCK或EINTR,则记录警告并返回错误
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
227
            WARN("UDS: Sending data over socket %s failed : %d (%s)", temp_addr, errno, strerror(errno));
228
229
230
231
232
233
234
235
236
237
            return scclSystemError;
        }
        // 如果设置了中止标志,则返回内部错误
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;
    }

    return scclSuccess;
}

238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
/**
 * @brief 通过IPC socket接收文件描述符
 *
 * 该函数使用recvmsg系统调用从socket接收文件描述符。函数会循环尝试接收,
 * 直到成功或发生错误。接收到的文件描述符会通过参数recvFd返回。
 *
 * @param recvFd 用于存储接收到的文件描述符的指针
 * @return scclResult_t 返回操作结果:
 *         - scclSuccess: 成功接收文件描述符
 *         - scclSystemError: 系统调用失败
 *         - scclInternalError: 操作被中止
 *
 * @note 函数会处理EAGAIN、EWOULDBLOCK和EINTR错误,其他错误会导致返回失败。
 *       接收到的控制消息必须符合SOL_SOCKET级别和SCM_RIGHTS类型。
 */
scclResult_t scclIpcSocket::scclIpcSocketRecvFd(int* recvFd) {
    // 初始化消息头结构体和iovec结构体
    struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
    struct iovec iov[1];

    // 联合体用于保证控制数组的对齐要求
    union {
        struct cmsghdr cm;
        char control[CMSG_SPACE(sizeof(int))];
    } control_un;

    struct cmsghdr* cmptr;
    char dummy_buffer[1];
    int ret;

    // 设置消息头的控制信息部分
    msg.msg_control    = control_un.control;
    msg.msg_controllen = sizeof(control_un.control);

    // 设置iovec结构体,用于指定要接收的数据
    iov[0].iov_base = (void*)dummy_buffer;
    iov[0].iov_len  = sizeof(dummy_buffer);

    // 将iovec结构体关联到消息头
    msg.msg_iov    = iov;
    msg.msg_iovlen = 1;

    // 循环接收消息,直到成功接收到数据
    while((ret = recvmsg(handle->fd, &msg, 0)) <= 0) {
        // 如果接收失败且错误不是EAGAIN, EWOULDBLOCK或EINTR,则记录警告并返回错误
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
284
            WARN("UDS: Receiving data over socket failed : %d (%s)", errno, strerror(errno));
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
            return scclSystemError;
        }
        // 如果设置了中止标志,则返回内部错误
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;
    }

    // 检查接收到的控制信息
    if(((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
        // 如果控制信息的级别或类型不正确,则记录警告并返回错误
        if((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
            WARN("UDS: Receiving data over socket failed");
            return scclSystemError;
        }

        // 将接收到的文件描述符复制到recvFd
        memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));
    } else {
        // 如果没有接收到控制信息,则记录警告并返回错误
        WARN("UDS: Receiving data over socket %s failed", handle->socketName);
        return scclSystemError;
    }

    // 记录成功接收到文件描述符的信息
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName);

    // 返回成功
    return scclSuccess;
}

315
scclResult_t scclIpcSocket::scclIpcSocketSendData(const void* data, size_t dataLen, int dst_rank) {
316
    const char* dataPtr = reinterpret_cast<const char*>(data);
317
318
319
320
    size_t bytesSent    = 0;
    while(bytesSent < dataLen) {
        size_t bytesToSend = std::min(CHUNK_SIZE, dataLen - bytesSent);
        // 发送数据块
321
        scclResult_t sendResult = scclIpcSocketSendDataAndRank(dataPtr + bytesSent, bytesToSend, dst_rank);
322
323
324
325
        if(sendResult != scclSuccess) {
            return sendResult;
        }
        bytesSent += bytesToSend;
326
    }
327
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully sent %zu bytes of data in chunks through UDS socket", dataLen);
328
329
    return scclSuccess;
}
330

331
332
scclResult_t scclIpcSocket::scclIpcSocketRecvData(void* buffer, size_t bufferLen, size_t* receivedLen, int* src_rank) {
    char* bufferPtr      = reinterpret_cast<char*>(buffer);
333
    size_t bytesReceived = 0;
334
    *receivedLen         = 0;
335
336
    while(bytesReceived < bufferLen) {
        size_t bytesToReceive = std::min(CHUNK_SIZE, bufferLen - bytesReceived);
337
        int recv_rank         = -1;
338
        // 接收数据块
339
340
341
        scclResult_t recvResult = scclIpcSocketRecvDataAndRank(bufferPtr + bytesReceived, bytesToReceive, receivedLen, &recv_rank);
        *src_rank               = recv_rank;

342
343
344
345
        if(recvResult != scclSuccess) {
            return recvResult;
        }
        bytesReceived += *receivedLen;
346
    }
347
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully received %zu bytes of data in chunks through UDS socket", bufferLen);
348
    return scclSuccess;
349
350
351
}

/**
352
 * @brief 通过Unix域套接字发送数据到指定目标,并等待ACK确认信息
353
 *
354
355
 * 该函数通过Unix域套接字发送数据到指定的目标rank,并等待接收ACK确认信息。
 * 如果接收到的ACK确认信息不正确,函数将返回错误。
356
 *
357
358
359
360
361
362
 * @param data 要发送的数据指针
 * @param dataLen 要发送的数据长度
 * @param dst_rank 目标rank号
 * @return scclResult_t 返回操作结果状态码:
 *         - scclSuccess: 操作成功
 *         - scclSystemError: 系统调用错误
363
 */
364
365
366
367
368
scclResult_t scclIpcSocket::scclIpcSocketSendDataWithAck(const void* data, size_t dataLen, int dst_rank) {
    // 发送数据和rank信息
    scclResult_t sendResult = scclIpcSocketSendDataAndRank(data, dataLen, dst_rank);
    if(sendResult != scclSuccess) {
        return sendResult;
369
370
    }

371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
    // 等待ACK
    char ack[ACK_SIZE];
    size_t ackLen;
    int ack_rank;
    scclResult_t recvAckResult = scclIpcSocketRecvDataAndRank(ack, ACK_SIZE, &ackLen, &ack_rank);
    if(recvAckResult != scclSuccess || ack_rank != dst_rank) {
        WARN("UDS: Failed to receive ACK from rank ack_rank:%d, dst_rank:%d", ack_rank, dst_rank);
        return scclSystemError;
    }
#if 0
    printf("scclIpcSocketSendDataWithAck localRank=%d, dst_rank=%d, ack_rank=%d, ack=%s\n", localRank, dst_rank, ack_rank, ack);
#endif
    // 对比ACK的字符串
    char expectedAck[ACK_SIZE];
    snprintf(expectedAck, ACK_SIZE, "ACK-%d", ack_rank);
    if(strncmp(ack, expectedAck, ACK_SIZE) != 0) {
        WARN("UDS: Received incorrect ACK from rank %d", dst_rank);
        return scclSystemError;
    }
390

391
392
    return scclSuccess;
}
393

394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
/**
 * @brief 通过Unix域套接字接收数据,并发送ACK确认信息
 *
 * 该函数通过Unix域套接字接收数据,并在接收完成后发送ACK确认信息给发送端。
 * 如果发送ACK确认信息失败,函数将返回错误。
 *
 * @param buffer 用于存储接收数据的缓冲区指针
 * @param bufferLen 缓冲区长度
 * @param receivedLen 接收到的数据长度(由函数设置)
 * @param src_rank 数据发送端的rank号(由函数设置)
 * @return scclResult_t 返回操作结果状态码:
 *         - scclSuccess: 操作成功
 *         - scclSystemError: 系统调用错误
 */
scclResult_t scclIpcSocket::scclIpcSocketRecvDataWithAck(void* buffer, size_t bufferLen, size_t* receivedLen, int* src_rank) {
    // 接收数据和rank信息
    scclResult_t recvResult = scclIpcSocketRecvDataAndRank(buffer, bufferLen, receivedLen, src_rank);
    if(recvResult != scclSuccess) {
        return recvResult;
413
    }
414
415
416
417
418
419
420
421
422
423
#if 0
    printf("scclIpcSocketRecvDataWithAck localRank=%d, src_rank=%d, bufferLen=%zu, receivedLen=%zu\n", localRank, *src_rank, bufferLen, *receivedLen);
#endif
    // 发送ACK
    char ack[ACK_SIZE];
    snprintf(ack, ACK_SIZE, "ACK-%d", localRank);
    scclResult_t sendAckResult = scclIpcSocketSendDataAndRank(ack, ACK_SIZE, *src_rank);
    if(sendAckResult != scclSuccess) {
        WARN("UDS: Failed to send ACK to rank %d", *src_rank);
        return scclSystemError;
424
425
426
427
428
    }

    return scclSuccess;
}

429
/////////////////////////////////////////////////////////////////////////////////////////////////////
430
431
432
433
434
435
436
437
438
439
440
441
442
443
/**
 * @brief 使用IPC套接字进行Allgather同步操作
 *
 * 该函数实现了基于IPC套接字的Allgather同步操作,将各进程的数据收集到所有进程的接收缓冲区中。
 *
 * @param sendData 发送数据缓冲区指针
 * @param recvData 接收数据缓冲区指针
 * @param dataLen 每个进程发送/接收的数据长度
 * @return scclResult_t 返回操作结果,成功返回scclSuccess,失败返回错误码
 *
 * @note 1. 函数会先将本地数据复制到接收缓冲区对应位置
 *       2. 使用线程池并行处理与其他进程的通信任务
 *       3. 当wait为true时会阻塞等待所有通信完成
 */
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
#if 1
// TODO: 当前为了保证正确性,性能太慢,后续优化
scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* recvData, size_t dataLen) {
    if(nlocalRanks <= 0) {
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }

    // 将当前进程的数据复制到接收缓冲区的对应位置
    auto all_recv_data = reinterpret_cast<char*>(recvData);
    memcpy(all_recv_data + localRank * dataLen, sendData, dataLen);

    // 当前rank的传输目标
    int next_rank_for_send = (localRank + 1 + nlocalRanks) % nlocalRanks;

    // Ring Allgather
    for(int step = 0; step < nlocalRanks - 1; ++step) {
        int next_rank_for_data = (localRank - step + nlocalRanks) % nlocalRanks;
        int prev_rank_for_data = (localRank - step - 1 + nlocalRanks) % nlocalRanks;
        // 准备发送/接收的数据
        auto send_data = all_recv_data + next_rank_for_data * dataLen;
        auto recv_data = all_recv_data + prev_rank_for_data * dataLen;

        auto sendTask = [this, send_data, dataLen, next_rank_for_send]() { scclIpcSocketSendDataBasic(send_data, dataLen, next_rank_for_send); };
        pthread_pool->enqueue(sendTask);

        auto recvTask = [this, recv_data, dataLen]() {
            size_t receivedLen;
            int recv_rank;
            scclIpcSocketRecvDataBasic(recv_data, dataLen, &receivedLen);
        };
        pthread_pool->enqueue(recvTask);

        // 等待所有任务完成
        while(!pthread_pool->allTasksCompleted()) {
            usleep(100); // 每1毫秒检查一次任务完成状态
        }
    }

    return scclSuccess;
}
#else
scclResult_t scclIpcSocket::scclIpcSocketAllgather(const void* sendData, void* recvData, size_t dataLen) {
487
    if(pthread_pool == nullptr || nlocalRanks <= 0) {
488
489
490
491
492
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }

    // 将当前进程的数据复制到接收缓冲区的对应位置
493
494
495
496
    auto all_recv_data = reinterpret_cast<char*>(recvData);
    memcpy(all_recv_data + localRank * dataLen, sendData, dataLen);
    char* temp_recvData;
    SCCLCHECK(scclCalloc(&temp_recvData, dataLen));
497
498

    // 采用线程池发送和接收数据
499
    for(int i = 0; i < nlocalRanks; ++i) {
500
        if(i != localRank) {
501
            auto sendTask = [this, sendData, dataLen, i]() { scclIpcSocketSendDataAndRank(sendData, dataLen, i); };
502
            pthread_pool->enqueue(sendTask);
503

504
            auto recvTask = [this, all_recv_data, dataLen, i, &temp_recvData]() {
505
                size_t receivedLen;
506
507
508
509
510
511
                int recv_rank;
                scclIpcSocketRecvDataAndRank(temp_recvData, dataLen, &receivedLen, &recv_rank);
                // printf("localRank=%d, recv_rank=%d, dataLen=%zu\n", localRank, recv_rank, dataLen);

                // 将数据拷贝到目标地址
                memcpy(all_recv_data + recv_rank * dataLen, temp_recvData, dataLen);
512
            };
513
            pthread_pool->enqueue(recvTask);
514
515
516
        }
    }

517
518
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
519
        usleep(100); // 每1毫秒检查一次任务完成状态
520
521
    }

522
523
    free(temp_recvData);

524
525
    return scclSuccess;
}
526
#endif
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543

/**
 * @brief 通过IPC Socket进行广播操作
 *
 * 该函数实现了基于IPC Socket的广播通信机制。根进程(root)将数据发送给所有其他进程,
 * 非根进程从根进程接收数据。可以选择是否等待所有通信操作完成。
 *
 * @param sendData 发送数据缓冲区指针(根进程使用)
 * @param recvData 接收数据缓冲区指针(非根进程使用)
 * @param dataLen 数据长度(字节)
 * @param root 根进程的rank值
 *
 * @return scclResult_t 返回操作结果状态码
 *     - scclSuccess: 操作成功
 *     - scclInternalError: IPC Socket未初始化或本地rank数无效
 *     - scclInvalidArgument: 根进程rank值无效
 */
544
#if 1
545
scclResult_t scclIpcSocket::scclIpcSocketBroadcast(void* data, size_t dataLen, int root) {
546
    if(pthread_pool == nullptr || nlocalRanks <= 0) {
547
548
549
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }
550
    if(root < 0 || root >= nlocalRanks) {
551
552
553
554
        WARN("scclIpcSocketBroadcast: Invalid root rank %d", root);
        return scclInvalidArgument;
    }

555
556
557
    if(localRank == root) {
        // 根进程:发送数据给所有其他进程
        for(int i = 0; i < nlocalRanks; ++i) {
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
            if(i != root) {
                const char* dataPtr = reinterpret_cast<const char*>(data);
                size_t bytesSent    = 0;
                while(bytesSent < dataLen) {
                    size_t bytesToSend      = std::min(CHUNK_SIZE, dataLen - bytesSent);
                    scclResult_t sendResult = scclIpcSocketSendDataWithAck(dataPtr + bytesSent, bytesToSend, i);
                    if(sendResult != scclSuccess) {
                        return sendResult;
                    }
                    bytesSent += bytesToSend;
                }
            }
        }
    } else {
        char* dataPtr        = reinterpret_cast<char*>(data);
        size_t bytesReceived = 0;
        while(bytesReceived < dataLen) {
            size_t bytesToReceive = std::min(CHUNK_SIZE, dataLen - bytesReceived);
            size_t receivedLen;
            int receivedRank;
            scclResult_t recvResult = scclIpcSocketRecvDataWithAck(dataPtr + bytesReceived, bytesToReceive, &receivedLen, &receivedRank);
            if(recvResult != scclSuccess) {
                return recvResult;
            }
            bytesReceived += receivedLen;
        }
    }

    return scclSuccess;
}

#else
scclResult_t scclIpcSocket::scclIpcSocketBroadcast(void* data, size_t dataLen, int root) {
    if(pthread_pool == nullptr || nlocalRanks <= 0) {
        WARN("scclIpcSocket init error!");
        return scclInternalError;
    }
    if(root < 0 || root >= nlocalRanks) {
        WARN("scclIpcSocketBroadcast: Invalid root rank %d", root);
        return scclInvalidArgument;
    }

    if(localRank == root) {
        // 根进程:发送数据给所有其他进程
        for(int i = 0; i < nlocalRanks; ++i) {
            // scclIpcSocketSendDataWithAck(data, dataLen, i);
604
605
606
607
608
609
610
611
612
            if(i != root) {
                // 使用 std::bind 绑定 scclIpcSocketSendDataWithAck 方法和参数
                auto sendTask = std::bind(&scclIpcSocket::scclIpcSocketSendDataWithAck, this, data, dataLen, i);
                // 将绑定后的函数对象添加到线程池的任务队列中
                pthread_pool->enqueue(sendTask);
            }
        }
    } else {
        size_t receivedLen;
613
614
        int receivedRank;
        scclResult_t result = scclIpcSocketRecvDataWithAck(data, dataLen, &receivedLen, &receivedRank);
615
616
        if(result != scclSuccess) {
            return result;
617
618
619
        }
    }

620
621
    // 等待所有任务完成
    while(!pthread_pool->allTasksCompleted()) {
622
        usleep(100); // 每1毫秒检查一次任务完成状态
623
624
    }

625
626
    return scclSuccess;
}
627
#endif
628
629
630
631
632
633
634

/////////////////////////////////////////////////////////////////////////////////////
scclResult_t scclIpcSocket::getScclIpcSocknameStr(int rank, uint64_t hash, char* out_str, int* out_len) {
    int len = snprintf(out_str, SCCL_IPC_SOCKNAME_LEN, "/tmp/sccl-socket-%d-%lx", rank, hash);
    if(len > (sizeof(my_cliaddr.sun_path) - 1)) {
        WARN("UDS: Cannot bind provided name to socket. Name too large");
        return scclInternalError;
635
636
    }

637
    *out_len = len;
638
639
640
    return scclSuccess;
}

641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
//////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * @brief 通过IPC套接字发送数据到指定目标rank
 *
 * 该函数通过Unix域套接字(UDS)发送数据到指定的目标rank。它首先构造目标地址字符串,
 * 然后设置消息结构体,包括目标地址和要发送的数据。接着,使用poll机制等待套接字可写,
 * 然后通过sendmsg函数发送数据。如果发送过程中出现错误,函数将根据错误类型采取相应的措施,
 * 包括重试发送或返回错误码。
 *
 * @param data 要发送的数据指针
 * @param dataLen 要发送的数据长度
 * @param dst_rank 目标rank号
 * @return scclResult_t 返回操作结果状态码:
 *         - scclSuccess: 发送成功
 *         - scclInternalError: 内部错误(如套接字名称过长或中止标志被设置)
 *         - scclSystemError: 系统调用错误(如poll超时或sendmsg失败)
 *
 * @note 使用Linux抽象套接字技术,通过poll机制确保套接字可写后再发送数据
 *       支持EAGAIN/EWOULDBLOCK/EINTR错误重试机制
 */
scclResult_t scclIpcSocket::scclIpcSocketSendDataBasic(const void* data, size_t dataLen, int dst_rank) {
    // 构造目标地址字符串
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];
    int len;
    SCCLCHECK(getScclIpcSocknameStr(dst_rank, ipc_hash, temp_addr, &len));
    // 设置消息结构体
    struct msghdr msg;
    struct iovec iov[1]; // 修改为1
    struct sockaddr_un cliaddr;
    bzero(&cliaddr, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;
    strncpy(cliaddr.sun_path, temp_addr, len);
    cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧
    msg.msg_name        = (void*)&cliaddr;
    msg.msg_namelen     = sizeof(cliaddr);
    msg.msg_control     = NULL;
    msg.msg_controllen  = 0;
    msg.msg_flags       = 0;
    // 准备数据
    iov[0].iov_base = (void*)data;
    iov[0].iov_len  = dataLen;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 1; // 修改为1
    // 使用 poll 等待 socket 可写
    struct pollfd pfd;
    pfd.fd         = handle->fd;
    pfd.events     = POLLOUT;
    int pollResult = poll(&pfd, 1, timeoutMs);
    if(pollResult <= 0) {
        if(pollResult == 0) {
            WARN("UDS: Timeout occurred while waiting to send data to socket %s", temp_addr);
        } else {
693
            WARN("UDS: Error occurred while polling socket %s for writability : %d (%s)", temp_addr, errno, strerror(errno));
694
695
696
697
698
699
        }
        return scclSystemError;
    }
    ssize_t sendResult;
    while((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) {
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
700
            WARN("UDS: Error occurred while sending data through socket %s : %d (%s)", temp_addr, errno, strerror(errno));
701
702
703
704
705
706
707
708
709
710
            return scclSystemError;
        }
        if(handle->abortFlag && *handle->abortFlag) {
            return scclInternalError;
        }
        pollResult = poll(&pfd, 1, timeoutMs);
        if(pollResult <= 0) {
            if(pollResult == 0) {
                WARN("UDS: Timeout occurred while waiting to send data to socket %s", temp_addr);
            } else {
711
                WARN("UDS: Error occurred while polling socket %s for writability : %d (%s)", temp_addr, errno, strerror(errno));
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
            }
            return scclSystemError;
        }
    }
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully sent %zd bytes of data through UDS socket %s", sendResult, temp_addr);
    return scclSuccess;
}

/**
 * @brief 通过IPC套接字接收数据
 *
 * 该函数通过Unix域套接字(UDS)接收数据。它首先设置消息结构体,包括接收缓冲区。
 * 然后,使用poll机制等待套接字可读,接着通过recvmsg函数接收数据。如果接收过程中出现错误,
 * 函数将根据错误类型采取相应的措施,包括重试接收或返回错误码。
 *
 * @param buffer 用于存储接收数据的缓冲区指针
 * @param bufferLen 缓冲区长度
 * @param receivedLen 接收的数据长度(由函数设置)
 * @param src_rank 数据发送端的rank号(由函数设置)
 * @return scclResult_t 返回操作结果状态码:
 *         - scclSuccess: 接收成功
 *         - scclSystemError: 系统调用错误(如poll超时或recvmsg失败)
 *
 * @note 使用Linux抽象套接字技术,通过poll机制确保套接字可读后再接收数据
 *       支持EAGAIN/EWOULDBLOCK/EINTR错误重试机制
 */
scclResult_t scclIpcSocket::scclIpcSocketRecvDataBasic(void* buffer, size_t bufferLen, size_t* receivedLen) {
    // 设置消息结构体
    struct msghdr msg = {0};
    struct iovec iov[1]; // 修改为1
    iov[0].iov_base = buffer;
    iov[0].iov_len  = bufferLen;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 1; // 修改为1
    // 使用 poll 等待 socket 可读
    struct pollfd pfd;
    pfd.fd         = handle->fd;
    pfd.events     = POLLIN;
    int pollResult = poll(&pfd, 1, timeoutMs);
    if(pollResult <= 0) {
        if(pollResult == 0) {
            WARN("UDS: Timeout occurred while waiting to receive data from socket %s", handle->socketName);
        } else {
            WARN("UDS: Error occurred while polling socket %s for readability : %d", handle->socketName, errno);
        }
        return scclSystemError;
    }
    int ret;
    while(true) {
        ret = recvmsg(handle->fd, &msg, 0);
        if(ret > 0) {
            INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully received %d bytes of data from socket %s", ret, handle->socketName);
            *receivedLen = ret; // 不再减去rank信息的长度
            // *src_rank    = rank; // 移除此行
            // 设置发送端的rank信息
            return scclSuccess;
        } else if(ret == 0) {
            INFO(SCCL_LOG_BOOTSTRAP, "UDS: Connection closed by peer on socket %s", handle->socketName);
            *receivedLen = 0;
            return scclSuccess;
        } else {
            if(errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
                pollResult = poll(&pfd, 1, timeoutMs);
                if(pollResult <= 0) {
                    if(pollResult == 0) {
                        WARN("UDS: Timeout occurred while waiting to receive data from socket %s", handle->socketName);
                    } else {
                        WARN("UDS: Error occurred while polling socket %s for readability : %d", handle->socketName, errno);
                    }
                    return scclSystemError;
                }
            } else {
                WARN("UDS: Error occurred while receiving data through socket %s : %d", handle->socketName, errno);
                return scclSystemError;
            }
        }
    }
}

/**
 * @brief 通过IPC socket接收文件描述符
 *
 * 该函数使用recvmsg系统调用从socket接收文件描述符。函数会循环尝试接收,
 * 直到成功或发生错误。接收到的文件描述符会通过参数recvFd返回。
 *
 * @param recvFd 用于存储接收到的文件描述符的指针
 * @return scclResult_t 返回操作结果:
 *         - scclSuccess: 成功接收文件描述符
 *         - scclSystemError: 系统调用失败
 *         - scclInternalError: 操作被中止
 *
 * @note 函数会处理EAGAIN、EWOULDBLOCK和EINTR错误,其他错误会导致返回失败。
 *       接收到的控制消息必须符合SOL_SOCKET级别和SCM_RIGHTS类型。
 */
scclResult_t scclIpcSocket::scclIpcSocketSendDataAndRank(const void* data, size_t dataLen, int dst_rank) {
    // 构造目标地址字符串
    char temp_addr[SCCL_IPC_SOCKNAME_LEN];
    int len;
    SCCLCHECK(getScclIpcSocknameStr(dst_rank, ipc_hash, temp_addr, &len));
    // 设置消息结构体
    struct msghdr msg;
    struct iovec iov[2]; // 修改为2,以便发送rank信息和数据
    struct sockaddr_un cliaddr;
    bzero(&cliaddr, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;
    strncpy(cliaddr.sun_path, temp_addr, len);
    cliaddr.sun_path[0] = '\0'; // Linux抽象套接字技巧
    msg.msg_name        = (void*)&cliaddr;
    msg.msg_namelen     = sizeof(cliaddr);
    msg.msg_control     = NULL;
    msg.msg_controllen  = 0;
    msg.msg_flags       = 0;
    // 准备rank信息
    int rank        = localRank;
    iov[0].iov_base = &rank;
    iov[0].iov_len  = sizeof(rank);
    // 准备数据
    iov[1].iov_base = (void*)data;
    iov[1].iov_len  = dataLen;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 2; // 修改为2
    // 使用 poll 等待 socket 可写
    struct pollfd pfd;
    pfd.fd         = handle->fd;
    pfd.events     = POLLOUT;
    int pollResult = poll(&pfd, 1, timeoutMs);
    if(pollResult <= 0) {
        if(pollResult == 0) {
            WARN("UDS: Timeout occurred while waiting to send data to socket %s", temp_addr);

        } else {
843
            WARN("UDS: Error occurred while polling socket %s for writability : %d (%s)", temp_addr, errno, strerror(errno));
844
845
846
847
848
849
        }
        return scclSystemError;
    }
    ssize_t sendResult;
    while((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) {
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
850
            WARN("UDS: Error occurred while sending data through socket %s : %d (%s)", temp_addr, errno, strerror(errno));
851
852
853
854
855
856
857
858
859
860
861
            return scclSystemError;
        }
        if(handle->abortFlag && *handle->abortFlag) {
            return scclInternalError;
        }
        pollResult = poll(&pfd, 1, timeoutMs);
        if(pollResult <= 0) {
            if(pollResult == 0) {
                WARN("UDS: Timeout occurred while waiting to send data to socket %s", temp_addr);

            } else {
862
                WARN("UDS: Error occurred while polling socket %s for writability : %d (%s)", temp_addr, errno, strerror(errno));
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
            }
            return scclSystemError;
        }
    }
    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully sent %zd bytes of data through UDS socket %s", sendResult, temp_addr);
    return scclSuccess;
}

/**
 * @brief 通过IPC套接字发送数据到指定目标rank
 *
 * @param data 要发送的数据指针
 * @param dataLen 要发送的数据长度
 * @param dst_rank 目标rank号
 * @return scclResult_t 返回操作结果状态码:
 *         - scclSuccess: 发送成功
 *         - scclInternalError: 内部错误(如套接字名称过长或中止标志被设置)
 *         - scclSystemError: 系统调用错误(如poll超时或sendmsg失败)
 *
 * @note 使用Linux抽象套接字技术,通过poll机制确保套接字可写后再发送数据
 *       支持EAGAIN/EWOULDBLOCK/EINTR错误重试机制
 */
scclResult_t scclIpcSocket::scclIpcSocketRecvDataAndRank(void* buffer, size_t bufferLen, size_t* receivedLen, int* src_rank) {
    // 设置消息结构体
    struct msghdr msg = {0};
    struct iovec iov[2]; // 修改为2,以便接收rank信息和数据
    int rank;
    iov[0].iov_base = &rank;
    iov[0].iov_len  = sizeof(rank);
    iov[1].iov_base = buffer;
    iov[1].iov_len  = bufferLen;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 2; // 修改为2
    // 使用 poll 等待 socket 可读
    struct pollfd pfd;
    pfd.fd         = handle->fd;
    pfd.events     = POLLIN;
    int pollResult = poll(&pfd, 1, timeoutMs);
    if(pollResult <= 0) {
        if(pollResult == 0) {
            WARN("UDS: Timeout occurred while waiting to receive data from socket %s", handle->socketName);

        } else {
            WARN("UDS: Error occurred while polling socket %s for readability : %d", handle->socketName, errno);
        }
        return scclSystemError;
    }
    int ret;
    while(true) {
        ret = recvmsg(handle->fd, &msg, 0);
        if(ret > 0) {
            INFO(SCCL_LOG_BOOTSTRAP, "UDS: Successfully received %d bytes of data from socket %s", ret, handle->socketName);
            *receivedLen = ret - sizeof(rank); // 减去rank信息的长度
            *src_rank    = rank;
            // 设置发送端的rank信息
            return scclSuccess;

        } else if(ret == 0) {
            INFO(SCCL_LOG_BOOTSTRAP, "UDS: Connection closed by peer on socket %s", handle->socketName);
            *receivedLen = 0;
            return scclSuccess;

        } else {
            if(errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
                pollResult = poll(&pfd, 1, timeoutMs);
                if(pollResult <= 0) {
                    if(pollResult == 0) {
                        WARN("UDS: Timeout occurred while waiting to receive data from socket %s", handle->socketName);

                    } else {
                        WARN("UDS: Error occurred while polling socket %s for readability : %d", handle->socketName, errno);
                    }
                    return scclSystemError;
                }

            } else {
                WARN("UDS: Error occurred while receiving data through socket %s : %d", handle->socketName, errno);
                return scclSystemError;
            }
        }
    }
}

946
947
948
949
} // namespace ipc_socket
} // namespace net
} // namespace hardware
} // namespace sccl