// net_socket.h: declarations for the socket-based network transport (namespace sccl::hardware::net::net_socket).
#pragma once

#include <assert.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <poll.h>
#include <sys/types.h>
#include <unistd.h>

#include "base.h"
#include "net_utils.h"
#include "socket.h"

namespace sccl {
namespace hardware {
namespace net {
namespace net_socket {

/* Compile-time limits used by the socket transport declarations below. */
// Capacity of scclNetSocketComm::socks and scclNetSocketRequest::tasks.
constexpr int MAX_SOCKETS = 64;
// Capacity of scclNetSocketComm::helperThread and ::threadResources.
constexpr int MAX_THREADS = 16;
// Capacity of scclNetSocketComm::requests.
constexpr int MAX_REQUESTS = SCCL_NET_MAX_REQUESTS;
// 64 KiB.
constexpr int MIN_CHUNKSIZE = 64 * 1024;

// State tag stored in scclNetSocketCommStage::state (one byte wide).
// The numeric values are explicit and value 2 is not assigned to any
// enumerator; keep the existing numbers when adding states.
// NOTE(review): presumably these are the steps of the non-blocking
// connect/accept handshake — confirm against the implementation.
enum scclNetSocketCommState : uint8_t {
    scclNetSocketCommStateStart   = 0,
    scclNetSocketCommStateConnect = 1,
    scclNetSocketCommStateAccept  = 3,
    scclNetSocketCommStateSend    = 4,
    scclNetSocketCommStateRecv    = 5,
};

struct scclNetSocketCommStage {
    enum scclNetSocketCommState state;
    uint8_t iteration;
    struct scclSocket* sock;
    struct scclNetSocketComm* comm = nullptr;
};

struct scclNetSocketHandle {
    union scclSocketAddress connectAddr;
    uint64_t magic; // random number to help debugging
    int nSocks;
    int nThreads;
    struct scclNetSocketCommStage stage;
};

struct scclNetSocketTask {
    int op;
    void* data;
    int size;
    struct scclSocket* sock = nullptr;
    int offset;
    int used;
    scclResult_t result;
};

struct scclNetSocketRequest {
    int op;
    void* data;
    int size;
    struct scclSocket* ctrlSock = nullptr;
    int offset;
    int used;
    struct scclNetSocketComm* comm               = nullptr;
    struct scclNetSocketTask* tasks[MAX_SOCKETS] = {nullptr};
    int nSubs;
};

struct scclNetSocketTaskQueue {
    int next;
    int len;
    struct scclNetSocketTask* tasks = nullptr;
};

struct scclNetSocketThreadResources {
    struct scclNetSocketTaskQueue threadTaskQueue;
    int stop;
    struct scclNetSocketComm* comm = nullptr;
    pthread_mutex_t threadLock;
    pthread_cond_t threadCond;
};

struct scclNetSocketListenComm {
    struct scclSocket sock;
    struct scclNetSocketCommStage stage;
    int nSocks;
    int nThreads;
    int dev;
};

struct scclNetSocketComm {
    struct scclSocket ctrlSock;
    struct scclSocket socks[MAX_SOCKETS];
    int dev;
    int hipDev;
    int nSocks;
    int nThreads;
    int nextSock;
    struct scclNetSocketRequest requests[MAX_REQUESTS];
    pthread_t helperThread[MAX_THREADS];
    struct scclNetSocketThreadResources threadResources[MAX_THREADS];
};

//////////////////////////////////
class scclNetSocket : public scclNetBase {
public:
    // 构造函数和析构函数
    scclNetSocket();
    virtual ~scclNetSocket();

    // 初始化网络。
    scclResult_t init() override;
    // 返回适配器的数量。
    scclResult_t devices(int* ndev) override;
    // 获取各种设备属性。
    scclResult_t getProperties(int dev, scclNetProperties_t* props) override;
    // 创建一个接收对象并提供一个句柄以连接到它。该句柄最多可以是 SCCL_NET_HANDLE_MAXSIZE 字节,并将在排名之间交换以创建连接。
    scclResult_t listen(int dev, void* handle, void** listenComm) override;
    // 连接到一个句柄并返回一个发送 comm 对象给该对等体。
    // 此调用不应阻塞以建立连接,而应成功返回 sendComm == NULL,并期望再次调用直到 sendComm != NULL。
    scclResult_t connect(int dev, void* handle, void** sendComm) override;
    // 在远程对等体调用 connect 后最终确定连接建立。
    // 此调用不应阻塞以建立连接,而应成功返回 recvComm == NULL,并期望再次调用直到 recvComm != NULL。
    scclResult_t accept(void* listenComm, void** recvComm) override;
    // 注册/注销内存。Comm 可以是 sendComm 或 recvComm。
    // 类型是 SCCL_PTR_HOST 或 SCCL_PTR_CUDA。
    scclResult_t regMr(void* comm, void* data, int size, int type, void** mhandle) override;
    /* DMA-BUF 支持 */
    scclResult_t regMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) override;
    // 注销IB内存区域(MR)
    scclResult_t deregMr(void* comm, void* mhandle) override;
    // 异步发送到对等体。
    // 如果调用不能执行(或会阻塞),则可能返回 request == NULL
    scclResult_t isend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) override;
    // 异步从对等体接收。 如果调用不能执行(或会阻塞),则可能返回 request == NULL
    scclResult_t irecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) override;
    // 执行刷新/栅栏操作,以确保所有使用 SCCL_PTR_CUDA 接收到的数据对 GPU 可见
    scclResult_t iflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) override;
    // 测试请求是否完成。如果 size 不为 NULL,则返回发送/接收的字节数。
    scclResult_t test(void* request, int* done, int* sizes) override;
    // 关闭并释放 send/recv comm 对象
    scclResult_t closeSend(void* sendComm) override;
    scclResult_t closeRecv(void* recvComm) override;
    scclResult_t closeListen(void* listenComm) override;

private:
    struct scclNetSocketListenComm* socketComm = nullptr;

private:
    // 获取网络设备的PCI路径
    static scclResult_t scclNetSocketGetPciPath(char* devName, char** pciPath);
    // 获取指定网络设备的速度(单位:Mbps)
    scclResult_t scclNetSocketGetSpeed(char* devName, int* speed);
    // 持久化socket线程处理函数
    static void* persistentSocketThread(void* args_);
    // 为指定通信对象创建并获取一个网络套接字任务
    scclResult_t scclNetSocketGetTask(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketTask** req);
    // 获取指定设备的socket和线程数量配置
    scclResult_t scclNetSocketGetNsockNthread(int dev, int* ns, int* nt);
};

} // namespace net_socket
} // namespace net
} // namespace hardware
} // namespace sccl