ipcsocket.cpp 7.88 KB
Newer Older
lishen's avatar
lishen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include "ipcsocket.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace bootstrap {
// Enable Linux abstract socket naming: before bind/sendmsg the first byte of
// sun_path is overwritten with '\0', so the name lives in the Linux abstract
// namespace instead of the filesystem, and the unlink() calls are compiled out.
#define USE_ABSTRACT_SOCKET

// printf-style template for per-rank socket names: rank (%d) and hash (%lx).
// Both the binding side (scclIpcSocketInit) and the sending side
// (scclIpcSocketSendFd) format names with this template, so they must agree.
// NOTE(review): "%lx" matches uint64_t only where uint64_t is unsigned long
// (e.g. LP64 Linux); PRIx64 from <inttypes.h> would be portable.
#define SCCL_IPC_SOCKNAME_STR "/tmp/sccl-socket-%d-%lx"

/**
 * @brief Initialize an IPC socket.
 *
 * Creates a UNIX-domain datagram socket and binds it to a name derived from
 * rank and hash. With USE_ABSTRACT_SOCKET the name is placed in the Linux
 * abstract namespace. Otherwise a filesystem socket path is used, and any
 * stale file at that path is unlinked first.
 *
 * @param handle    Output handle. On entry fd is reset to -1 and socketName
 *                  to "". On success it receives the bound fd and the name.
 * @param rank      Process rank, used to make the socket name unique.
 * @param hash      Hash value combined with rank to make the name unique.
 * @param abortFlag Optional abort flag. If non-NULL, the socket is switched to
 *                  non-blocking mode and the send/recv retry loops poll it.
 * @return scclSuccess on success; scclInternalError for a NULL handle or a
 *         name that cannot be formatted or does not fit in sun_path;
 *         scclSystemError when socket() or bind() fails.
 */
scclResult_t scclIpcSocketInit(scclIpcSocket* handle, int rank, uint64_t hash, volatile uint32_t* abortFlag) {
    int fd = -1;
    struct sockaddr_un cliaddr;
    char temp[SCCL_IPC_SOCKNAME_LEN] = "";

    if(handle == NULL) {
        return scclInternalError;
    }

    handle->fd            = -1;
    handle->socketName[0] = '\0';
    if((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) {
        WARN("UDS: Socket creation error : %d", errno);
        return scclSystemError;
    }

    bzero(&cliaddr, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;

    // Create unique name for the socket.
    // Reject three cases:
    //  - snprintf failure (len < 0);
    //  - truncation into temp, which could make two ranks' names collide;
    //  - names that leave no room in sun_path.
    // The socket is already open here, so close it to avoid leaking the fd.
    int len = snprintf(temp, SCCL_IPC_SOCKNAME_LEN, SCCL_IPC_SOCKNAME_STR, rank, hash);
    if(len < 0 || len >= SCCL_IPC_SOCKNAME_LEN || (size_t)len > (sizeof(cliaddr.sun_path) - 1)) {
        WARN("UDS: Cannot bind provided name to socket. Name too large");
        close(fd);
        return scclInternalError;
    }
#ifndef USE_ABSTRACT_SOCKET
    // Remove a leftover socket file from a previous run so bind() can succeed.
    unlink(temp);
#endif

    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Creating socket %s", temp);

    // cliaddr was zeroed above, so copying exactly len bytes leaves the rest of
    // sun_path NUL-filled.
    memcpy(cliaddr.sun_path, temp, len);
#ifdef USE_ABSTRACT_SOCKET
    cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#endif
    // The full sizeof(cliaddr) is used as the address length; the sender uses
    // sizeof(struct sockaddr_un) too, so both sides name the same address.
    if(bind(fd, (struct sockaddr*)&cliaddr, sizeof(cliaddr)) < 0) {
        WARN("UDS: Binding to socket %s failed : %d", temp, errno);
        close(fd);
        return scclSystemError;
    }

    handle->fd = fd;
    strcpy(handle->socketName, temp); // temp is NUL-terminated and shorter than SCCL_IPC_SOCKNAME_LEN

    handle->abortFlag = abortFlag;
    // Mark socket as non-blocking so the send/recv retry loops can poll the
    // abort flag between EAGAIN retries instead of blocking indefinitely.
    // NOTE(review): if EQCHECK/SYSCHECK return early, handle->fd is already
    // set, so the caller is expected to release it via scclIpcSocketClose.
    if(handle->abortFlag) {
        int flags;
        EQCHECK(flags = fcntl(fd, F_GETFL), -1);
        SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
    }

    return scclSuccess;
}

/**
 * @brief Close an IPC socket and release its descriptor.
 *
 * @param handle Socket handle, normally initialized by scclIpcSocketInit.
 * @return scclSuccess on success, including when there is nothing to close;
 *         scclInternalError if handle is NULL.
 *
 * @note A descriptor <= 0 is treated as "not open", so a zero-initialized or
 *       failed-init handle can be closed safely.
 * @note After closing, fd is reset to -1 and socketName is cleared. A repeated
 *       call is then a no-op, instead of closing (or unlinking) something a
 *       later socket may have reused.
 * @note With USE_ABSTRACT_SOCKET there is no filesystem entry to unlink.
 *       Otherwise the path stored in socketName is removed.
 */
scclResult_t scclIpcSocketClose(scclIpcSocket* handle) {
    if(handle == NULL) {
        return scclInternalError;
    }
    if(handle->fd <= 0) {
        return scclSuccess;
    }
#ifndef USE_ABSTRACT_SOCKET
    if(handle->socketName[0] != '\0') {
        unlink(handle->socketName);
    }
#endif
    close(handle->fd);

    // Invalidate the handle so a second Close cannot act on a reused fd/path.
    handle->fd            = -1;
    handle->socketName[0] = '\0';

    return scclSuccess;
}

/**
 * @brief Receive a file descriptor over the IPC socket.
 *
 * Receives one datagram and expects it to carry a single SCM_RIGHTS control
 * message holding exactly one descriptor. That descriptor is stored in
 * *recvFd.
 *
 * @param handle Initialized IPC socket handle.
 * @param recvFd Output location for the received descriptor.
 * @return One of:
 *         - scclSuccess on success.
 *         - scclSystemError if recvmsg fails with a non-retryable error, or
 *           the message carries no valid SCM_RIGHTS descriptor.
 *         - scclInternalError if handle or recvFd is NULL, or the abort flag
 *           is set while waiting.
 *
 * @note Without an abort flag the socket is blocking and recvmsg blocks. With
 *       one (non-blocking socket), EAGAIN/EWOULDBLOCK/EINTR are retried in a
 *       loop that checks the abort flag between attempts.
 */
scclResult_t scclIpcSocketRecvFd(scclIpcSocket* handle, int* recvFd) {
    if(handle == NULL || recvFd == NULL) {
        return scclInternalError;
    }

    struct msghdr msg;
    struct iovec iov[1];

    // Union to guarantee alignment requirements for control array
    union {
        struct cmsghdr cm;
        char control[CMSG_SPACE(sizeof(int))];
    } control_un;

    struct cmsghdr* cmptr;
    char dummy_buffer[1];
    ssize_t ret;

    memset(&msg, 0, sizeof(msg));
    msg.msg_control    = control_un.control;
    msg.msg_controllen = sizeof(control_un.control);

    // The payload is a single dummy byte; the descriptor travels in the
    // ancillary (control) data.
    iov[0].iov_base = (void*)dummy_buffer;
    iov[0].iov_len  = sizeof(dummy_buffer);

    msg.msg_iov    = iov;
    msg.msg_iovlen = 1;

    // Retry only while recvmsg reports an error (< 0). A return of 0 is a
    // received zero-length datagram, not a failure, and errno is not updated
    // for it. Such a message is validated below like any other.
    while((ret = recvmsg(handle->fd, &msg, 0)) < 0) {
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
            WARN("UDS: Receiving data over socket failed : %d", errno);
            return scclSystemError;
        }
        // Non-blocking mode: let the caller stop waiting.
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;
    }

    // Expect one SCM_RIGHTS control message carrying exactly one int.
    if(((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
        if((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
            WARN("UDS: Receiving data over socket failed");
            return scclSystemError;
        }

        // CMSG_DATA may be unaligned for int; copy bytes instead of dereferencing.
        memcpy(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));
    } else {
        WARN("UDS: Receiving data over socket %s failed", handle->socketName);
        return scclSystemError;
    }

    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName);

    return scclSuccess;
}

/**
 * @brief Send a file descriptor to another rank's IPC socket.
 *
 * Builds the destination name from (rank, hash) with the same template and
 * abstract-namespace trick as scclIpcSocketInit. It then sends a one-byte
 * datagram whose SCM_RIGHTS control message carries sendFd.
 *
 * @param handle Initialized local IPC socket used as the sender.
 * @param sendFd Descriptor to pass to the peer.
 * @param rank   Rank of the destination process.
 * @param hash   Hash value used (with rank) to build the destination name.
 * @return One of:
 *         - scclSuccess on success.
 *         - scclInternalError for a NULL handle, a name that cannot be
 *           formatted or does not fit, or when the abort flag is set while
 *           retrying.
 *         - scclSystemError when sendmsg fails with a non-retryable error.
 *
 * @note When the socket is non-blocking (abort flag present), EAGAIN,
 *       EWOULDBLOCK and EINTR are retried until the send succeeds or the
 *       abort flag is set.
 */
scclResult_t scclIpcSocketSendFd(scclIpcSocket* handle, const int sendFd, int rank, uint64_t hash) {
    if(handle == NULL) {
        return scclInternalError;
    }

    struct msghdr msg;
    struct iovec iov[1];
    char temp[SCCL_IPC_SOCKNAME_LEN];

    // Union to guarantee alignment requirements for control array
    union {
        struct cmsghdr cm;
        char control[CMSG_SPACE(sizeof(int))];
    } control_un;

    struct cmsghdr* cmptr;
    struct sockaddr_un cliaddr;

    // Zero the header and control buffer so no indeterminate stack bytes reach
    // sendmsg. CMSG_SPACE may exceed CMSG_LEN, and that padding is sent too.
    memset(&msg, 0, sizeof(msg));
    memset(&control_un, 0, sizeof(control_un));

    // Construct client address to send this shareable handle to
    bzero(&cliaddr, sizeof(cliaddr));
    cliaddr.sun_family = AF_UNIX;

    // Reject snprintf failure, truncation into temp (could target the wrong
    // peer), and names that leave no room in sun_path.
    int len = snprintf(temp, SCCL_IPC_SOCKNAME_LEN, SCCL_IPC_SOCKNAME_STR, rank, hash);
    if(len < 0 || len >= SCCL_IPC_SOCKNAME_LEN || (size_t)len > (sizeof(cliaddr.sun_path) - 1)) {
        WARN("UDS: Cannot connect to provided name for socket. Name too large");
        return scclInternalError;
    }
    // cliaddr is zeroed, so copying exactly len bytes leaves sun_path NUL-filled.
    memcpy(cliaddr.sun_path, temp, len);

    INFO(SCCL_LOG_BOOTSTRAP, "UDS: Sending fd %d to UDS socket %s", sendFd, temp);

#ifdef USE_ABSTRACT_SOCKET
    cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#endif

    msg.msg_control    = control_un.control;
    msg.msg_controllen = sizeof(control_un.control);

    // Single SCM_RIGHTS message carrying exactly one descriptor.
    cmptr             = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_len   = CMSG_LEN(sizeof(int));
    cmptr->cmsg_level = SOL_SOCKET;
    cmptr->cmsg_type  = SCM_RIGHTS;

    memcpy(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd));

    // Full sockaddr_un length, matching the bind() in scclIpcSocketInit.
    msg.msg_name    = (void*)&cliaddr;
    msg.msg_namelen = sizeof(struct sockaddr_un);

    // One dummy payload byte; the descriptor travels in the control data.
    iov[0].iov_base = (void*)"";
    iov[0].iov_len  = 1;
    msg.msg_iov     = iov;
    msg.msg_iovlen  = 1;
    msg.msg_flags   = 0;

    // Retry only on an actual error return; errno is meaningless otherwise.
    ssize_t sendResult;
    while((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) {
        if(errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
            WARN("UDS: Sending data over socket %s failed : %d", temp, errno);
            return scclSystemError;
        }
        if(handle->abortFlag && *handle->abortFlag)
            return scclInternalError;
    }

    return scclSuccess;
}

} // namespace bootstrap
} // namespace topology
} // namespace hardware
} // namespace sccl