"scripts/convert_vae_pt_to_diffusers.py" did not exist on "d43972ae71ffbc94dee7045ecfe1e5f7c6ac329e"
Commit a8e057dd authored by Antoine Kaufmann's avatar Antoine Kaufmann
Browse files

dist/rdma: new API refactor

parent 38fc5ec5
...@@ -190,10 +190,12 @@ int BasePeerSetupQueues(struct Peer *peer) { ...@@ -190,10 +190,12 @@ int BasePeerSetupQueues(struct Peer *peer) {
peer->shm_base = shm_base; peer->shm_base = shm_base;
peer->local_base = (void *) ((uintptr_t) shm_base + li->c2l_offset); peer->local_base = (void *) ((uintptr_t) shm_base + li->c2l_offset);
peer->local_offset = li->c2l_offset;
peer->local_elen = li->c2l_elen; peer->local_elen = li->c2l_elen;
peer->local_enum = li->c2l_nentries; peer->local_enum = li->c2l_nentries;
peer->cleanup_base = (void *) ((uintptr_t) shm_base + li->l2c_offset); peer->cleanup_base = (void *) ((uintptr_t) shm_base + li->l2c_offset);
peer->cleanup_offset = li->l2c_offset;
peer->cleanup_elen = li->l2c_elen; peer->cleanup_elen = li->l2c_elen;
peer->cleanup_enum = li->l2c_nentries; peer->cleanup_enum = li->l2c_nentries;
...@@ -335,10 +337,12 @@ int BasePeerEvent(struct Peer *peer, uint32_t events) { ...@@ -335,10 +337,12 @@ int BasePeerEvent(struct Peer *peer, uint32_t events) {
struct SimbricksProtoListenerIntro *li = struct SimbricksProtoListenerIntro *li =
(struct SimbricksProtoListenerIntro *) peer->intro_local; (struct SimbricksProtoListenerIntro *) peer->intro_local;
peer->local_base = (void *) ((uintptr_t) peer->shm_base + li->l2c_offset); peer->local_base = (void *) ((uintptr_t) peer->shm_base + li->l2c_offset);
peer->local_offset = li->l2c_offset;
peer->local_elen = li->l2c_elen; peer->local_elen = li->l2c_elen;
peer->local_enum = li->l2c_nentries; peer->local_enum = li->l2c_nentries;
peer->cleanup_base = (void *) ((uintptr_t) peer->shm_base + li->c2l_offset); peer->cleanup_base = (void *) ((uintptr_t) peer->shm_base + li->c2l_offset);
peer->cleanup_offset = li->c2l_offset;
peer->cleanup_elen = li->c2l_elen; peer->cleanup_elen = li->c2l_elen;
peer->cleanup_enum = li->c2l_nentries; peer->cleanup_enum = li->c2l_nentries;
} else { } else {
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
struct Peer { struct Peer {
/* base address of the local queue we're polling. */ /* base address of the local queue we're polling. */
uint8_t *local_base; uint8_t *local_base;
uint64_t local_offset;
uint32_t local_elen; uint32_t local_elen;
uint32_t local_enum; uint32_t local_enum;
uint32_t local_pos; uint32_t local_pos;
...@@ -55,6 +56,7 @@ struct Peer { ...@@ -55,6 +56,7 @@ struct Peer {
this periodically and keep track of the last communicated position in this periodically and keep track of the last communicated position in
`cleanup_pos_reported`. */ `cleanup_pos_reported`. */
uint8_t *cleanup_base; uint8_t *cleanup_base;
uint64_t cleanup_offset;
uint32_t cleanup_elen; uint32_t cleanup_elen;
uint32_t cleanup_enum; uint32_t cleanup_enum;
// next position to be cleaned up // next position to be cleaned up
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#include <sys/mman.h> #include <sys/mman.h>
#include <unistd.h> #include <unistd.h>
#include <simbricks/proto/base.h> #include <simbricks/base/proto.h>
#include "dist/common/utils.h" #include "dist/common/utils.h"
...@@ -58,14 +58,14 @@ static void PrintUsage() { ...@@ -58,14 +58,14 @@ static void PrintUsage() {
fprintf(stderr, fprintf(stderr,
"Usage: net_rdma [OPTIONS] IP PORT\n" "Usage: net_rdma [OPTIONS] IP PORT\n"
" -l: Listen instead of connecting\n" " -l: Listen instead of connecting\n"
" -d DEV-SOCKET: network socket of a device simulator\n" " -L LISTEN-SOCKET: listening socket for a simulator\n"
" -n NET-SOCKET: network socket of a network simulator\n" " -C CONN-SOCKET: connecting socket for a simulator\n"
" -s SHM-PATH: shared memory region path\n" " -s SHM-PATH: shared memory region path\n"
" -S SHM-SIZE: shared memory region size in MB (default 256)\n"); " -S SHM-SIZE: shared memory region size in MB (default 256)\n");
} }
static int ParseArgs(int argc, char *argv[]) { static int ParseArgs(int argc, char *argv[]) {
const char *opts = "ld:n:s:S:D:ip:g:"; const char *opts = "lL:C:s:S:D:ip:g:";
int c; int c;
while ((c = getopt(argc, argv, opts)) != -1) { while ((c = getopt(argc, argv, opts)) != -1) {
...@@ -74,13 +74,13 @@ static int ParseArgs(int argc, char *argv[]) { ...@@ -74,13 +74,13 @@ static int ParseArgs(int argc, char *argv[]) {
mode_listen = true; mode_listen = true;
break; break;
case 'd': case 'L':
if (!NetPeerAdd(optarg, true)) if (!BasePeerAdd(optarg, true))
return 1; return 1;
break; break;
case 'n': case 'C':
if (!NetPeerAdd(optarg, false)) if (!BasePeerAdd(optarg, false))
return 1; return 1;
break; break;
...@@ -134,7 +134,7 @@ static int ParseArgs(int argc, char *argv[]) { ...@@ -134,7 +134,7 @@ static int ParseArgs(int argc, char *argv[]) {
static void *PollThread(void *data) { static void *PollThread(void *data) {
while (true) while (true)
NetPoll(); BasePoll();
return NULL; return NULL;
} }
...@@ -150,7 +150,7 @@ static int IOLoop() { ...@@ -150,7 +150,7 @@ static int IOLoop() {
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
struct Peer *peer = evs[i].data.ptr; struct Peer *peer = evs[i].data.ptr;
if (peer && NetPeerEvent(peer, evs[i].events)) if (peer && BasePeerEvent(peer, evs[i].events))
return 1; return 1;
else if (!peer && RdmaEvent()) else if (!peer && RdmaEvent())
return 1; return 1;
...@@ -173,10 +173,10 @@ int main(int argc, char *argv[]) { ...@@ -173,10 +173,10 @@ int main(int argc, char *argv[]) {
return EXIT_FAILURE; return EXIT_FAILURE;
} }
if (NetInit(shm_path, shm_size, epfd)) if (BaseInit(shm_path, shm_size, epfd))
return EXIT_FAILURE; return EXIT_FAILURE;
if (NetListen()) if (BaseListen())
return EXIT_FAILURE; return EXIT_FAILURE;
if (mode_listen) { if (mode_listen) {
...@@ -189,7 +189,7 @@ int main(int argc, char *argv[]) { ...@@ -189,7 +189,7 @@ int main(int argc, char *argv[]) {
printf("RDMA connected\n"); printf("RDMA connected\n");
fflush(stdout); fflush(stdout);
if (NetConnect()) if (BaseConnect())
return EXIT_FAILURE; return EXIT_FAILURE;
printf("Peers initialized\n"); printf("Peers initialized\n");
fflush(stdout); fflush(stdout);
......
...@@ -25,13 +25,12 @@ ...@@ -25,13 +25,12 @@
#ifndef DIST_NET_RDMA_H_ #ifndef DIST_NET_RDMA_H_
#define DIST_NET_RDMA_H_ #define DIST_NET_RDMA_H_
#include "dist/common/net.h" #include "dist/common/base.h"
#include <arpa/inet.h> #include <arpa/inet.h>
#include <stdbool.h> #include <stdbool.h>
#include <stddef.h> #include <stddef.h>
#include <simbricks/proto/network.h>
// configuration variables // configuration variables
extern size_t shm_size; extern size_t shm_size;
......
...@@ -39,6 +39,11 @@ ...@@ -39,6 +39,11 @@
#define MAX_PEERS 32 #define MAX_PEERS 32
#define SIG_THRESHOLD 32 #define SIG_THRESHOLD 32
struct NetRdmaIntroMsg {
uint32_t payload_len;
uint8_t data[1024];
} __attribute__((packed));
struct NetRdmaReportMsg { struct NetRdmaReportMsg {
uint32_t written_pos[MAX_PEERS]; uint32_t written_pos[MAX_PEERS];
uint32_t clean_pos[MAX_PEERS]; uint32_t clean_pos[MAX_PEERS];
...@@ -47,8 +52,7 @@ struct NetRdmaReportMsg { ...@@ -47,8 +52,7 @@ struct NetRdmaReportMsg {
struct NetRdmaMsg { struct NetRdmaMsg {
union { union {
struct SimbricksProtoNetDevIntro dev; struct NetRdmaIntroMsg intro;
struct SimbricksProtoNetNetIntro net;
struct NetRdmaReportMsg report; struct NetRdmaReportMsg report;
struct NetRdmaMsg *next_free; struct NetRdmaMsg *next_free;
}; };
...@@ -57,8 +61,7 @@ struct NetRdmaMsg { ...@@ -57,8 +61,7 @@ struct NetRdmaMsg {
uint64_t queue_off; uint64_t queue_off;
uint64_t rkey; uint64_t rkey;
enum { enum {
kMsgDev, kMsgIntro,
kMsgNet,
kMsgReport, kMsgReport,
} msg_type; } msg_type;
} __attribute__((packed)); } __attribute__((packed));
...@@ -115,40 +118,44 @@ static int RdmMsgRxEnqueue(struct NetRdmaMsg *msg) { ...@@ -115,40 +118,44 @@ static int RdmMsgRxEnqueue(struct NetRdmaMsg *msg) {
static int RdmaMsgRxIntro(struct NetRdmaMsg *msg) { static int RdmaMsgRxIntro(struct NetRdmaMsg *msg) {
if (msg->id >= peer_num) { if (msg->id >= peer_num) {
fprintf(stderr, "RdmMsgRx: invalid peer id in message (%lu)\n", msg->id); fprintf(stderr, "RdmaMsgRxIntro: invalid peer id in message (%lu)\n",
msg->id);
abort(); abort();
} }
struct Peer *peer = peers + msg->id; struct Peer *peer = peers + msg->id;
printf("RdmMsgRx -> peer %s\n", peer->sock_path); printf("RdmMsgRx -> peer %s\n", peer->sock_path);
if (peer->is_dev != (msg->msg_type == kMsgNet)) {
fprintf(stderr, "RdmMsgRx: unexpetced message type (%u)\n", msg->msg_type);
abort();
}
if (peer->intro_valid_remote) { if (peer->intro_valid_remote) {
fprintf(stderr, "RdmMsgRx: received multiple messages (%lu)\n", msg->id); fprintf(stderr, "RdmaMsgRxIntro: received multiple messages (%lu)\n",
msg->id);
abort(); abort();
} }
peer->remote_rkey = msg->rkey; peer->remote_rkey = msg->rkey;
peer->remote_base = msg->base_addr + msg->queue_off; peer->remote_base = msg->base_addr + msg->queue_off;
peer->intro_valid_remote = true; peer->intro_valid_remote = true;
if (peer->is_dev) { peer->intro_remote_len = msg->intro.payload_len;
peer->net_intro = msg->net; memcpy(peer->intro_remote, msg->intro.data, msg->intro.payload_len);
if (NetPeerSendDevIntro(peer))
return 1; if (BasePeerSetupQueues(peer)) {
} else { fprintf(stderr, "RdmaMsgRxIntro(%s): queue setup failed\n",
peer->dev_intro = msg->dev; peer->sock_path);
if (NetPeerSetupNetQueues(peer)) abort();
return 1;
if (peer->intro_valid_local && NetOpPassIntro(peer))
return 1;
} }
if (BasePeerSendIntro(peer))
return 1;
if (peer->intro_valid_local) { if (peer->intro_valid_local) {
fprintf(stderr, "RdmMsgRx(%s): marking peer as ready\n", peer->sock_path); // now we can send our intro for a listener
if (peer->is_listener && BaseOpPassIntro(peer)) {
fprintf(stderr, "RdmaMsgRxIntro(%s): sending l intro failed\n",
peer->sock_path);
return 1;
}
fprintf(stderr, "RdmaMsgRxIntro(%s): marking peer as ready\n",
peer->sock_path);
peer->ready = true; peer->ready = true;
} }
return 0; return 0;
...@@ -163,14 +170,14 @@ static int RdmaMsgRxReport(struct NetRdmaMsg *msg) { ...@@ -163,14 +170,14 @@ static int RdmaMsgRxReport(struct NetRdmaMsg *msg) {
fprintf(stderr, "RdmaMsgRxReport: invalid ready peer number %zu\n", i); fprintf(stderr, "RdmaMsgRxReport: invalid ready peer number %zu\n", i);
abort(); abort();
} }
NetPeerReport(&peers[i], msg->report.written_pos[i], BasePeerReport(&peers[i], msg->report.written_pos[i],
msg->report.clean_pos[i]); msg->report.clean_pos[i]);
} }
return 0; return 0;
} }
static int RdmaMsgRx(struct NetRdmaMsg *msg) { static int RdmaMsgRx(struct NetRdmaMsg *msg) {
if (msg->msg_type == kMsgDev || msg->msg_type == kMsgNet) if (msg->msg_type == kMsgIntro)
return RdmaMsgRxIntro(msg); return RdmaMsgRxIntro(msg);
else if (msg->msg_type == kMsgReport) else if (msg->msg_type == kMsgReport)
return RdmaMsgRxReport(msg); return RdmaMsgRxReport(msg);
...@@ -336,7 +343,7 @@ int RdmaEvent() { ...@@ -336,7 +343,7 @@ int RdmaEvent() {
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
if (wcs[i].opcode == IBV_WC_SEND) { if (wcs[i].opcode == IBV_WC_SEND) {
#ifdef RDMA_DEBUG #ifdef RDMA_DEBUG
fprintf(stderr, "Send done\n", n); fprintf(stderr, "Send done\n");
#endif #endif
if (wcs[i].status != IBV_WC_SUCCESS) { if (wcs[i].status != IBV_WC_SUCCESS) {
fprintf(stderr, "RdmaEvent: unsuccessful send (%u)\n", wcs[i].status); fprintf(stderr, "RdmaEvent: unsuccessful send (%u)\n", wcs[i].status);
...@@ -347,7 +354,7 @@ int RdmaEvent() { ...@@ -347,7 +354,7 @@ int RdmaEvent() {
RdmaMsgFree(msgs + wcs[i].wr_id); RdmaMsgFree(msgs + wcs[i].wr_id);
} else if ((wcs[i].opcode & IBV_WC_RECV)) { } else if ((wcs[i].opcode & IBV_WC_RECV)) {
#ifdef RDMA_DEBUG #ifdef RDMA_DEBUG
fprintf(stderr, "Recv done\n", n); fprintf(stderr, "Recv done\n");
#endif #endif
if (wcs[i].status != IBV_WC_SUCCESS) { if (wcs[i].status != IBV_WC_SUCCESS) {
...@@ -370,17 +377,17 @@ int RdmaEvent() { ...@@ -370,17 +377,17 @@ int RdmaEvent() {
return 0; return 0;
} }
int NetOpPassIntro(struct Peer *peer) { int BaseOpPassIntro(struct Peer *peer) {
#ifdef RDMA_DEBUG #ifdef RDMA_DEBUG
fprintf(stderr, "NetOpPassIntro(%s)\n", peer->sock_path); fprintf(stderr, "BaseOpPassIntro(%s)\n", peer->sock_path);
#endif #endif
// device peers have sent us an SHM region, need to register this an as MR // connecting peers have sent us an SHM region, need to register this an as MR
if (peer->is_dev) { if (!peer->is_listener) {
if (!(peer->shm_opaque = ibv_reg_mr(pd, peer->shm_base, peer->shm_size, if (!(peer->shm_opaque = ibv_reg_mr(pd, peer->shm_base, peer->shm_size,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_LOCAL_WRITE |
IBV_ACCESS_REMOTE_WRITE))) { IBV_ACCESS_REMOTE_WRITE))) {
perror("NetOpPassIntro: ibv_reg_mr shm failed"); perror("BaseOpPassIntro: ibv_reg_mr shm failed");
return 1; return 1;
} }
} else { } else {
...@@ -388,13 +395,11 @@ int NetOpPassIntro(struct Peer *peer) { ...@@ -388,13 +395,11 @@ int NetOpPassIntro(struct Peer *peer) {
intro from our RDMA peer, so we can include the queue position. */ intro from our RDMA peer, so we can include the queue position. */
if (!peer->intro_valid_remote) { if (!peer->intro_valid_remote) {
fprintf(stderr, fprintf(stderr,
"NetOpPassIntro: skipping because remote intro not received\n"); "BaseOpPassIntro: skipping because remote intro not received\n");
return 0; return 0;
} }
peer->shm_opaque = mr_shm; peer->shm_opaque = mr_shm;
peer->shm_base = shm_base;
peer->shm_size = shm_size;
} }
struct NetRdmaMsg *msg = RdmaMsgAlloc(); struct NetRdmaMsg *msg = RdmaMsgAlloc();
...@@ -405,19 +410,14 @@ int NetOpPassIntro(struct Peer *peer) { ...@@ -405,19 +410,14 @@ int NetOpPassIntro(struct Peer *peer) {
msg->base_addr = (uintptr_t) peer->shm_base; msg->base_addr = (uintptr_t) peer->shm_base;
struct ibv_mr *mr = peer->shm_opaque; struct ibv_mr *mr = peer->shm_opaque;
msg->rkey = mr->rkey; msg->rkey = mr->rkey;
if (peer->is_dev) { msg->msg_type = kMsgIntro;
msg->msg_type = kMsgDev; msg->queue_off = peer->cleanup_offset;
/* this is a device peer, meaning the remote side will write to the msg->intro.payload_len = peer->intro_local_len;
network-to-device queue. */ if (peer->intro_local_len > sizeof(msg->intro.data)) {
msg->queue_off = peer->dev_intro.n2d_offset; fprintf(stderr, "BaseOpPassIntro: intro longer than buffer\n");
msg->dev = peer->dev_intro; abort();
} else {
msg->msg_type = kMsgNet;
/* this is a network peer, meaning the remote side will write to the
device-to-network queue. */
msg->queue_off = peer->dev_intro.d2n_offset;
msg->net = peer->net_intro;
} }
memcpy(msg->intro.data, peer->intro_local, peer->intro_local_len);
struct ibv_sge sge; struct ibv_sge sge;
sge.addr = (uintptr_t) msg; sge.addr = (uintptr_t) msg;
...@@ -433,19 +433,19 @@ int NetOpPassIntro(struct Peer *peer) { ...@@ -433,19 +433,19 @@ int NetOpPassIntro(struct Peer *peer) {
struct ibv_send_wr *bad_send_wr; struct ibv_send_wr *bad_send_wr;
if (ibv_post_send(qp, &send_wr, &bad_send_wr)) { if (ibv_post_send(qp, &send_wr, &bad_send_wr)) {
perror("RdmaPassIntro: ibv_post_send failed"); perror("BaseOpPassIntro: ibv_post_send failed");
return 1; return 1;
} }
#ifdef RDMA_DEBUG #ifdef RDMA_DEBUG
fprintf(stderr, "RdmaPassIntro: ibv_post_send done\n"); fprintf(stderr, "BaseOpPassIntro: ibv_post_send done\n");
#endif #endif
return 0; return 0;
} }
int NetOpPassEntries(struct Peer *peer, uint32_t pos, uint32_t n) { int BaseOpPassEntries(struct Peer *peer, uint32_t pos, uint32_t n) {
#ifdef RDMA_DEBUG #ifdef RDMA_DEBUG
fprintf(stderr, "NetOpPassEntries(%s,%u)\n", peer->sock_path, fprintf(stderr, "BaseOpPassEntries(%s,%u)\n", peer->sock_path,
pos); pos);
fprintf(stderr, " remote_base=%lx local_base=%p\n", peer->remote_base, fprintf(stderr, " remote_base=%lx local_base=%p\n", peer->remote_base,
peer->local_base); peer->local_base);
...@@ -478,7 +478,7 @@ int NetOpPassEntries(struct Peer *peer, uint32_t pos, uint32_t n) { ...@@ -478,7 +478,7 @@ int NetOpPassEntries(struct Peer *peer, uint32_t pos, uint32_t n) {
if (ret == 0) { if (ret == 0) {
break; break;
} else if (ret != ENOMEM) { } else if (ret != ENOMEM) {
fprintf(stderr, "NetOpPassEntries: ibv_post_send failed %d (%s)\n", ret, fprintf(stderr, "BaseOpPassEntries: ibv_post_send failed %d (%s)\n", ret,
strerror(ret)); strerror(ret));
return 1; return 1;
} }
...@@ -486,9 +486,9 @@ int NetOpPassEntries(struct Peer *peer, uint32_t pos, uint32_t n) { ...@@ -486,9 +486,9 @@ int NetOpPassEntries(struct Peer *peer, uint32_t pos, uint32_t n) {
return 0; return 0;
} }
int NetOpPassReport() { int BaseOpPassReport() {
if (peer_num > MAX_PEERS) { if (peer_num > MAX_PEERS) {
fprintf(stderr, "NetOpPassReport: peer_num (%zu) larger than max (%u)\n", fprintf(stderr, "BaseOpPassReport: peer_num (%zu) larger than max (%u)\n",
peer_num, MAX_PEERS); peer_num, MAX_PEERS);
abort(); abort();
} }
......
...@@ -37,7 +37,6 @@ ...@@ -37,7 +37,6 @@
#include <unistd.h> #include <unistd.h>
#include <simbricks/base/proto.h> #include <simbricks/base/proto.h>
#include <simbricks/network/proto.h>
#include "dist/common/base.h" #include "dist/common/base.h"
#include "dist/common/utils.h" #include "dist/common/utils.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment