Commit a4ac3320 authored by lishen's avatar lishen
Browse files

通过线程池实现ipcsocket,满足节点内通信

parent d9d23f34
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <vector> // 引入vector库
#include <thread> // 为了使用 std::this_thread::sleep_for
#include "mpi.h"
#include "net.h"
#include "ipc_socket.h"
#include "thread_pool.h"
using namespace sccl;
typedef class sccl::hardware::net::ipc_socket::scclIpcSocket scclIpcSocket_t;
/**
 * Thread-pool task: send one buffer to a peer through an IPC socket.
 *
 * Forwards the arguments to `ipcsocket->scclIpcSocketSendData(...)`. If the
 * call reports anything other than `scclSuccess`, an error line is printed via
 * perror() and the whole MPI job is aborted with exit code 1.
 *
 * @tparam T        Socket type exposing `scclIpcSocketSendData`.
 * @param ipcsocket Socket used for the transfer.
 * @param data      Start of the payload.
 * @param dataLen   Payload size in bytes.
 * @param dst_rank  Destination rank, passed through unchanged.
 * @param dst_hash  Destination hash, passed through unchanged.
 */
template <typename T>
void send_data(T* ipcsocket, const void* data, size_t dataLen, int dst_rank, uint64_t dst_hash) {
    const bool sent = ipcsocket->scclIpcSocketSendData(data, dataLen, dst_rank, dst_hash) == scclSuccess;
    if(sent)
        return;

    perror("Failed to send data");
    MPI_Abort(MPI_COMM_WORLD, 1);
}
/**
 * Thread-pool task: receive one buffer through an IPC socket.
 *
 * Forwards the arguments to `ipcsocket->scclIpcSocketRecvData(...)`. If the
 * call reports anything other than `scclSuccess`, an error line is printed via
 * perror() and the whole MPI job is aborted with exit code 1.
 *
 * @tparam T          Socket type exposing `scclIpcSocketRecvData`.
 * @param ipcsocket   Socket used for the transfer.
 * @param buffer      Destination buffer.
 * @param bufferLen   Capacity of `buffer` in bytes.
 * @param receivedLen Out-parameter handed to the socket call.
 */
template <typename T>
void recv_data(T* ipcsocket, void* buffer, size_t bufferLen, size_t* receivedLen) {
    const bool received = ipcsocket->scclIpcSocketRecvData(buffer, bufferLen, receivedLen) == scclSuccess;
    if(received)
        return;

    perror("Failed to receive data");
    MPI_Abort(MPI_COMM_WORLD, 1);
}
/**
 * All-to-all smoke test for the IPC socket.
 *
 * Every MPI rank builds a 256-byte message, then queues one send task and one
 * receive task per peer on a thread pool. After a fixed wait it prints the
 * receive buffer slot for every rank.
 *
 * Each receive task writes its length into its own slot of `receivedLens`, so
 * concurrent tasks no longer race on one shared variable, and the lengths are
 * only read after the wait instead of immediately after enqueueing.
 */
int main(int argc, char* argv[]) {
    MPI_Init(&argc, &argv);

    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    int dst_hash = 12345;
    scclIpcSocket_t ipcsocket(rank, dst_hash);

    int sendDataLen = 256;
    std::vector<char> sendData(sendDataLen);
    std::vector<char> recvData(size * sendDataLen);
    // One zero-initialised length per peer: no shared writes, no uninitialised reads.
    std::vector<size_t> receivedLens(size, 0);

    // Fill the outgoing message.
    snprintf(sendData.data(), sendData.size(), "Data from process %d", rank);

    auto pthpool = ThreadPool(size * 2);

    // Queue a send to, and a receive for, every other process.
    for(int i = 0; i < size; ++i) {
        if(i != rank) {
            auto task_send = std::bind(send_data<scclIpcSocket_t>, &ipcsocket, sendData.data(), sendData.size(), i, dst_hash);
            pthpool.enqueue(task_send);

            auto task_recv =
                std::bind(recv_data<scclIpcSocket_t>, &ipcsocket, recvData.data() + i * sendDataLen, sendDataLen, &receivedLens[i]);
            pthpool.enqueue(task_recv);
        }
    }

    // NOTE(review): the fixed sleep is the only synchronisation with the pool;
    // ThreadPool's completion API is not visible here, so it is kept as-is.
    std::this_thread::sleep_for(std::chrono::seconds(2));

    // Print what was received.
    // NOTE(review): slot i is assumed to hold the message from rank i — confirm
    // against scclIpcSocketRecvData, which is given no source rank.
    for(int i = 0; i < size; ++i) {
        if(i != rank)
            printf("sendData.size()=%d, receivedLen=%d\n", sendDataLen, int(receivedLens[i]));
        printf("Process %d received from process %d: %s\n", rank, i, recvData.data() + i * sendDataLen);
    }

    MPI_Finalize();
    return 0;
}
/*
单机执行
SCCL_DEBUG_LEVEL=ABORT SCCL_DEBUG_SUBSYS=BOOTSTRAP mpirun --allow-run-as-root -np 8 3_socket_mpi_data
*/
# Build the three socket/MPI test programs with hipcc.
#
# The three command lines were identical apart from the program name, so the
# shared source list, flags, include paths and libraries live in build_test().

SCCL_ROOT=/public/home/lishen/Code/rocSHMEM/SCCL_v1
SCCL_SRC=${SCCL_ROOT}/src
OMPI_ROOT=/public/home/lishen/Code/rocSHMEM/3rd_party/install/ompi
DTK_ROOT=/opt/dtk

# build_test NAME
#   Compile ./NAME.cpp together with the SCCL sources into the executable NAME.
build_test() {
    hipcc "./$1.cpp" \
        "${SCCL_SRC}/hardware/hardware_utils.cpp" \
        "${SCCL_SRC}/hardware/net/net_ib/ibvsymbols.cpp" \
        "${SCCL_SRC}/hardware/net/net_ib/ibvwrap.cpp" \
        "${SCCL_SRC}/hardware/net/net_ib/net_ib.cpp" \
        "${SCCL_SRC}/hardware/net/net_socket/net_socket.cpp" \
        "${SCCL_SRC}/hardware/net/net_socket/socket.cpp" \
        "${SCCL_SRC}/hardware/net/net.cpp" \
        "${SCCL_SRC}/hardware/net/net_utils.cpp" \
        "${SCCL_SRC}/hardware/net/rocm_wrap.cpp" \
        "${SCCL_SRC}/hardware/net/ipc_socket/ipc_socket.cpp" \
        "${SCCL_SRC}/utils/archinfo.cpp" \
        "${SCCL_SRC}/utils/param.cpp" \
        "${SCCL_SRC}/utils/utils.cpp" \
        "${SCCL_SRC}/utils/thread_pool.cpp" \
        -o "$1" \
        -std=c++17 -g -O3 -fopenmp -DROC_SHMEM -D__HIP_PLATFORM_HCC__ -Wno-return-type \
        -I ./ -I /usr/include -I "${DTK_ROOT}/include" \
        -I "${OMPI_ROOT}/include/" \
        -I "${SCCL_SRC}" \
        -I "${SCCL_SRC}/utils/" \
        -I "${SCCL_SRC}/include/" \
        -I "${SCCL_SRC}/hardware/net/net_ib/" \
        -I "${SCCL_SRC}/hardware/net/net_socket/" \
        -I "${SCCL_SRC}/hardware/net/ipc_socket/" \
        -I "${SCCL_SRC}/hardware/net/" \
        -I "${SCCL_SRC}/hardware/topology/bootstrap/" \
        -I "${SCCL_SRC}/hardware/" \
        -I "${SCCL_SRC}/hardware/topology/" \
        -L "${SCCL_ROOT}" \
        -L /usr/lib/x86_64-linux-gnu -libverbs -lrdmacm \
        -L "${OMPI_ROOT}/lib" -lmpi \
        -L "${DTK_ROOT}/lib" -lamdhip64 -lrocm-core -lrocm_smi64 -pthread
}

build_test 1_socket_mpi_fd
build_test 2_socket_mpi_fd_pthpool
build_test 3_socket_mpi_data
#include <iostream>
#include <string>
#include <cstring>
#include <unistd.h>
#include <arpa/inet.h>
/**
 * Connect to a TCP server, send one greeting and print the reply.
 *
 * Any failure (socket creation, address parsing, connect, send, read) prints
 * a message to stderr, closes the socket if it was opened, and terminates the
 * process with EXIT_FAILURE.
 *
 * @param server_ip   Dotted-quad IPv4 address of the server.
 * @param server_port TCP port of the server.
 */
void start_client(const std::string& server_ip, int server_port) {
    char buffer[1024] = {0};
    const std::string message = "你好,服务器!";

    // Create the socket file descriptor.
    const int sock = socket(AF_INET, SOCK_STREAM, 0);
    if(sock < 0) {
        std::cerr << "Socket creation error" << std::endl;
        exit(EXIT_FAILURE);
    }

    // Zero the whole structure so padding (sin_zero) is not left as stack garbage.
    struct sockaddr_in serv_addr;
    memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_port   = htons(server_port);

    // Convert the textual IPv4 address to binary form.
    if(inet_pton(AF_INET, server_ip.c_str(), &serv_addr.sin_addr) <= 0) {
        std::cerr << "Invalid address/ Address not supported" << std::endl;
        close(sock);
        exit(EXIT_FAILURE);
    }

    // Connect to the server.
    if(connect(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) < 0) {
        std::cerr << "Connection Failed" << std::endl;
        close(sock);
        exit(EXIT_FAILURE);
    }

    // Send the greeting; report a failed send instead of claiming success.
    if(send(sock, message.c_str(), message.length(), 0) < 0) {
        std::cerr << "Send failed" << std::endl;
        close(sock);
        exit(EXIT_FAILURE);
    }
    std::cout << "消息已发送" << std::endl;

    // Receive the reply. One byte is reserved so the buffer can always be
    // NUL-terminated before it is streamed as a C string.
    const ssize_t valread = read(sock, buffer, sizeof(buffer) - 1);
    if(valread < 0) {
        std::cerr << "Read failed" << std::endl;
        close(sock);
        exit(EXIT_FAILURE);
    }
    buffer[valread] = '\0';
    std::cout << "收到的响应: " << buffer << std::endl;

    // Close the connection.
    close(sock);
}
/** Entry point: run the client once against the hard-coded server endpoint. */
int main() {
    const std::string server_ip   = "10.16.1.37";
    const int         server_port = 6842;

    start_client(server_ip, server_port);
    return 0;
}
\ No newline at end of file
#include <iostream>
#include <ifaddrs.h>
#include <arpa/inet.h>
#include <net/if.h>
#include <stdlib.h>
#include <netdb.h>
#include <unistd.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <vector>
#include <utility>
#include <unordered_set>
#include <unistd.h>
#include <sys/syscall.h>
#define NI_MAXHOST 1025
/**
 * Print the name and numeric IPv4 address of every local interface.
 *
 * Exits with EXIT_FAILURE if the interface list cannot be obtained. An
 * interface whose address cannot be formatted is reported on stderr and
 * skipped, rather than printing an uninitialised buffer.
 */
void get_ip_addresses() {
    struct ifaddrs* ifaddr = NULL;
    if(getifaddrs(&ifaddr) == -1) {
        perror("getifaddrs");
        exit(EXIT_FAILURE);
    }

    for(struct ifaddrs* ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) {
        // Entries without an address, or with a non-IPv4 address, are ignored.
        if(ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != AF_INET)
            continue;

        char      host[NI_MAXHOST];
        const int rc = getnameinfo(ifa->ifa_addr, sizeof(struct sockaddr_in), host, NI_MAXHOST, NULL, 0, NI_NUMERICHOST);
        if(rc != 0) {
            std::cerr << "getnameinfo failed for " << ifa->ifa_name << ": " << gai_strerror(rc) << std::endl;
            continue;
        }
        std::cout << "Interface: " << ifa->ifa_name << " Address: " << host << std::endl;
    }

    freeifaddrs(ifaddr);
}
/** Entry point: list the local IPv4 interfaces and exit. */
int main() {
    get_ip_addresses();

    return 0;
}
\ No newline at end of file
#include <iostream>
#include <string>
#include <cstring>
#include <unistd.h>
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
/**
 * Run a single-connection TCP echo-acknowledge server on port 6842.
 *
 * Binds to all local addresses, accepts exactly one client, then prints each
 * received message and answers with a fixed acknowledgement until the peer
 * closes the connection or an I/O error occurs. Setup failures print a
 * diagnostic via perror() and terminate the process with EXIT_FAILURE.
 */
void start_server() {
    char              buffer[1024] = {0};
    const std::string message      = "消息已收到";

    // Create the socket file descriptor. socket() signals failure with -1;
    // 0 is a valid descriptor.
    const int server_fd = socket(AF_INET, SOCK_STREAM, 0);
    if(server_fd < 0) {
        perror("socket failed");
        exit(EXIT_FAILURE);
    }

    // Bind address and port. The structure is zeroed first so padding is defined.
    struct sockaddr_in address;
    memset(&address, 0, sizeof(address));
    address.sin_family      = AF_INET;
    address.sin_addr.s_addr = INADDR_ANY; // listen on every local address
    address.sin_port        = htons(6842);
    if(bind(server_fd, (struct sockaddr*)&address, sizeof(address)) < 0) {
        perror("bind failed");
        close(server_fd);
        exit(EXIT_FAILURE);
    }

    // Read back the bound port number.
    socklen_t len = sizeof(address);
    if(getsockname(server_fd, (struct sockaddr*)&address, &len) == -1) {
        perror("getsockname failed");
        close(server_fd);
        exit(EXIT_FAILURE);
    }
    int port = ntohs(address.sin_port);
    std::cout << "服务器已启动,端口: " << port << std::endl;

    // Listen for connections.
    if(listen(server_fd, 3) < 0) {
        perror("listen");
        close(server_fd);
        exit(EXIT_FAILURE);
    }
    std::cout << "等待连接..." << std::endl;

    // Accept one client. A real socklen_t is used instead of casting an int pointer.
    socklen_t addrlen    = sizeof(address);
    const int new_socket = accept(server_fd, (struct sockaddr*)&address, &addrlen);
    if(new_socket < 0) {
        perror("accept");
        close(server_fd);
        exit(EXIT_FAILURE);
    }

    while(true) {
        // Receive data, keeping one byte free for the terminating NUL.
        const ssize_t valread = read(new_socket, buffer, sizeof(buffer) - 1);
        if(valread == 0) {
            break; // peer closed the connection
        }
        if(valread < 0) {
            // Previously a read error looped forever; stop serving instead.
            perror("read");
            break;
        }
        buffer[valread] = '\0';
        std::cout << "收到的消息: " << buffer << std::endl;

        if(send(new_socket, message.c_str(), message.length(), 0) < 0) {
            perror("send");
            break;
        }
    }

    // Close the connection.
    close(new_socket);
    close(server_fd);
}
/** Entry point: serve one client connection, then exit. */
int main() {
    start_server();

    return 0;
}
\ No newline at end of file
#include "socket.h"
#include <stdlib.h>
#include <unistd.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <vector>
#include <utility>
#include <unordered_set>
#include <unistd.h>
#include <sys/syscall.h>
using namespace sccl;
// Candidate local ports used by scclSocketConnect when port reuse is requested.
// Each entry pairs a port number with the set of remote-peer strings (as
// produced by scclSocketToString) that have already been assigned that port.
// NOTE(review): this is a plain file-scope static filled lazily without a lock;
// confirm how callers guarantee it is not populated from several threads at once.
static std::vector<std::pair<int, std::unordered_set<std::string>>> clientPortPool;
/**
 * Move as many bytes as currently possible between `ptr` and the socket.
 *
 * Repeats recv()/send() starting at `*offset` while each call transfers data
 * and the total is below `size`, advancing `*offset` by every transferred
 * chunk.
 *
 * @param op     SCCL_SOCKET_RECV or SCCL_SOCKET_SEND.
 * @param sock   Socket whose `fd` is used; `abortFlag` is polled after each call.
 * @param ptr    Buffer to receive into or send from.
 * @param size   Total number of bytes to transfer.
 * @param offset In/out count of bytes already transferred.
 * @param block  Non-zero for a blocking call; zero adds MSG_DONTWAIT.
 * @param closed Set to 1 when a recv() returns 0 (peer closed), otherwise 0.
 * @return scclSuccess (also when the peer closed or the call would block),
 *         scclRemoteError on a hard socket error, scclInternalError when the
 *         abort flag is raised.
 */
static scclResult_t socketProgressOpt(int op, struct scclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) {
    int bytes = 0;
    *closed   = 0;
    char* data = (char*)ptr;
    char line[SOCKET_NAME_MAXLEN + 1];
    do {
        if(op == SCCL_SOCKET_RECV)
            bytes = recv(sock->fd, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT);
        if(op == SCCL_SOCKET_SEND)
            bytes = send(sock->fd, data + (*offset), size - (*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL);
        if(op == SCCL_SOCKET_RECV && bytes == 0) {
            *closed = 1;
            return scclSuccess;
        }
        if(bytes == -1) {
            // Capture errno before any other call (address formatting) can overwrite it.
            const int err = errno;
            if(err != EINTR && err != EWOULDBLOCK && err != EAGAIN) {
                // Name the operation that actually failed instead of always saying "recv".
                WARN("socketProgressOpt: Call to %s %s failed : %s",
                     op == SCCL_SOCKET_RECV ? "recv from" : "send to",
                     scclSocketToString(&sock->addr, line),
                     strerror(err));
                return scclRemoteError;
            } else {
                // Interrupted or would block: no progress this round.
                bytes = 0;
            }
        }
        (*offset) += bytes;
        if(sock->abortFlag && *sock->abortFlag != 0) {
            INFO(SCCL_LOG_CODEALL, "socketProgressOpt: abort called");
            return scclInternalError;
        }
    } while(bytes > 0 && (*offset) < size);
    return scclSuccess;
}
/**
 * Non-blocking progress step that treats a closed connection as an error.
 *
 * Delegates to socketProgressOpt with `block = 0`. When the peer has closed
 * the connection a warning is logged and scclRemoteError is returned.
 */
static scclResult_t socketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
    int peerClosed;
    SCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0 /*block*/, &peerClosed));

    if(!peerClosed)
        return scclSuccess;

    char peerName[SOCKET_NAME_MAXLEN + 1];
    WARN("socketProgress: Connection closed by remote peer %s", scclSocketToString(&sock->addr, peerName, 0));
    return scclRemoteError;
}
/**
 * Keep calling socketProgress until `*offset` reaches `size`.
 *
 * Any error from socketProgress is propagated through SCCLCHECK.
 */
static scclResult_t socketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
    for(; *offset < size;) {
        SCCLCHECK(socketProgress(op, sock, ptr, size, offset));
    }
    return scclSuccess;
}
/* Format a string representation of a (union scclSocketAddress *) socket address using getnameinfo()
 *
 * Output: "IPv4/IPv6 address<port>"
 *
 * Returns NULL when `buf` or `addr` is NULL. Returns `buf` holding an empty
 * string when the address family is neither AF_INET nor AF_INET6, or when
 * getnameinfo() fails (previously the uninitialised host/service buffers were
 * formatted in that case).
 *
 * `buf` must be large enough for the host, the service and the "<>" delimiters;
 * callers in this file pass SOCKET_NAME_MAXLEN + 1 bytes.
 */
const char* scclSocketToString(union scclSocketAddress* addr, char* buf, const int numericHostForm /*= 1*/) {
    if(buf == NULL || addr == NULL)
        return NULL;
    struct sockaddr* saddr = &addr->sa;
    if(saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) {
        buf[0] = '\0';
        return buf;
    }
    char host[NI_MAXHOST], service[NI_MAXSERV];
    /* NI_NUMERICHOST: If set, then the numeric form of the hostname is returned.
     * (When not set, this will still happen in case the node's name cannot be determined.)
     */
    int flag = NI_NUMERICSERV | (numericHostForm ? NI_NUMERICHOST : 0);
    if(getnameinfo(saddr, sizeof(union scclSocketAddress), host, NI_MAXHOST, service, NI_MAXSERV, flag) != 0) {
        // host/service are indeterminate on failure; hand back an empty string.
        buf[0] = '\0';
        return buf;
    }
    sprintf(buf, "%s<%s>", host, service);
    return buf;
}
/**
 * Return the port stored in `addr` in host byte order.
 *
 * The IPv4 field is read for AF_INET addresses; every other family is read
 * through the IPv6 field.
 */
static uint16_t socketToPort(union scclSocketAddress* addr) {
    if(addr->sa.sa_family == AF_INET)
        return ntohs(addr->sin.sin_port);
    return ntohs(addr->sin6.sin6_port);
}
/* Let the user force IPv4 or IPv6 interface selection through the
 * SCCL_SOCKET_FAMILY environment variable.
 *
 * Returns AF_INET for "AF_INET", AF_INET6 for "AF_INET6", and -1 (no forced
 * family, the first one found is used) when the variable is unset or holds
 * any other value.
 */
static int envSocketFamily(void) {
    const char* requested = getenv("SCCL_SOCKET_FAMILY");
    if(requested == NULL)
        return -1;

    INFO(SCCL_LOG_CODEALL, "SCCL_SOCKET_FAMILY set by environment to %s", requested);

    if(strcmp(requested, "AF_INET") == 0)
        return AF_INET; // IPv4
    if(strcmp(requested, "AF_INET6") == 0)
        return AF_INET6; // IPv6
    return -1;
}
static int findInterfaces(const char* prefixList, char* names, union scclSocketAddress* addrs, int sock_family, int maxIfNameSize, int maxIfs) {
struct netIf userIfs[MAX_IFS];
bool searchNot = prefixList && prefixList[0] == '^';
if(searchNot)
prefixList++;
bool searchExact = prefixList && prefixList[0] == '=';
if(searchExact)
prefixList++;
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
for(interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) {
if(interface->ifa_addr == NULL)
continue;
/* We only support IPv4 & IPv6 */
int family = interface->ifa_addr->sa_family;
if(family != AF_INET && family != AF_INET6)
continue;
/* Allow the caller to force the socket family type */
if(sock_family != -1 && family != sock_family)
continue;
/* We also need to skip IPv6 loopback interfaces */
if(family == AF_INET6) {
struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr);
if(IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr))
continue;
}
// check against user specified interfaces
if(!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
continue;
}
// Check that this interface has not already been saved
// getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link
bool duplicate = false;
for(int i = 0; i < found; i++) {
if(strcmp(interface->ifa_name, names + i * maxIfNameSize) == 0) {
duplicate = true;
break;
}
}
if(!duplicate) {
// Store the interface name
strncpy(names + found * maxIfNameSize, interface->ifa_name, maxIfNameSize);
// Store the IP address
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
memcpy(addrs + found, interface->ifa_addr, salen);
found++;
}
}
freeifaddrs(interfaces);
return found;
}
static bool matchSubnet(struct ifaddrs local_if, union scclSocketAddress* remote) {
/* Check family first */
int family = local_if.ifa_addr->sa_family;
if(family != remote->sa.sa_family) {
return false;
}
if(family == AF_INET) {
struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr);
struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask);
struct sockaddr_in& remote_addr = remote->sin;
struct in_addr local_subnet, remote_subnet;
local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr;
return (local_subnet.s_addr ^ remote_subnet.s_addr) ? false : true;
} else if(family == AF_INET6) {
struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr);
struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask);
struct sockaddr_in6& remote_addr = remote->sin6;
struct in6_addr& local_in6 = local_addr->sin6_addr;
struct in6_addr& mask_in6 = mask->sin6_addr;
struct in6_addr& remote_in6 = remote_addr.sin6_addr;
bool same = true;
int len = 16; // IPv6 address is 16 unsigned char
for(int c = 0; c < len; c++) { // Network byte order is big-endian
char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c];
char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c];
if(c1 ^ c2) {
same = false;
break;
}
}
// At last, we need to compare scope id
// Two Link-type addresses can have the same subnet address even though they are not in the same scope
// For Global type, this field is 0, so a comparison wouldn't matter
same &= (local_addr->sin6_scope_id == remote_addr.sin6_scope_id);
return same;
} else {
WARN("Net : Unsupported address family type");
return false;
}
}
/**
 * Find a local interface that is in the same subnet as `remoteAddr`.
 *
 * The scan stops at the first match, so at most one interface is returned
 * regardless of `maxIfs`.
 *
 * @param ifNames       Output array of name slots, `ifNameMaxSize` bytes each.
 * @param localAddrs    Output array of local addresses.
 * @param remoteAddr    Remote address to match against.
 * @param ifNameMaxSize Size in bytes of one name slot.
 * @param maxIfs        Capacity of the output arrays; nothing is written when <= 0.
 * @return Number of interfaces stored (0 or 1). A warning is logged when none
 *         is found.
 */
int scclFindInterfaceMatchSubnet(char* ifNames, union scclSocketAddress* localAddrs, union scclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) {
    char line_a[SOCKET_NAME_MAXLEN + 1];
    int found = 0;

    // Without output capacity the first match would be written out of bounds.
    if(maxIfs <= 0) {
        return 0;
    }

    struct ifaddrs *interfaces = NULL, *interface;
    // On failure the list pointer is not valid and must not be walked or freed.
    if(getifaddrs(&interfaces) != 0) {
        WARN("scclFindInterfaceMatchSubnet: getifaddrs failed : %s", strerror(errno));
        return 0;
    }

    for(interface = interfaces; interface && !found; interface = interface->ifa_next) {
        if(interface->ifa_addr == NULL)
            continue;

        /* We only support IPv4 & IPv6 */
        int family = interface->ifa_addr->sa_family;
        if(family != AF_INET && family != AF_INET6)
            continue;

        // Skip interfaces that are not in the remote address's subnet
        if(!matchSubnet(*interface, remoteAddr)) {
            continue;
        }

        // Store the local IP address
        int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
        memcpy(localAddrs + found, interface->ifa_addr, salen);

        // Store the interface name, always NUL-terminated even if truncated
        char* slot = ifNames + found * ifNameMaxSize;
        strncpy(slot, interface->ifa_name, ifNameMaxSize - 1);
        slot[ifNameMaxSize - 1] = '\0';

        found++;
        if(found == maxIfs)
            break;
    }

    if(found == 0) {
        WARN("Net : No interface found in the same subnet as remote address %s", scclSocketToString(remoteAddr, line_a));
    }
    freeifaddrs(interfaces);
    return found;
}
/**
 * Parse "<ipv4_or_hostname>:<port>" or "[<ipv6>[%<ifname>]]:<port>" into `ua`.
 *
 * The non-bracketed form is resolved with getaddrinfo() and the first result
 * is used. The bracketed form is parsed by hand: the text before an optional
 * '%' is the IPv6 address, the text after it is the interface name for a
 * link-scope address, and the text after "]:" is the port.
 *
 * @param ua           Output address.
 * @param ip_port_pair Input string.
 * @return scclSuccess, or scclInvalidArgument when the string is missing,
 *         malformed, too long for the internal buffers, or cannot be
 *         resolved/converted.
 */
scclResult_t scclSocketGetAddrFromString(union scclSocketAddress* ua, const char* ip_port_pair) {
    if(!(ip_port_pair && strlen(ip_port_pair) > 1)) {
        WARN("Net : string is null");
        return scclInvalidArgument;
    }

    bool ipv6 = ip_port_pair[0] == '[';
    /* Construct the sockaddress structure */
    if(!ipv6) {
        struct netIf ni;
        // parse <ip_or_hostname>:<port> string, expect one pair
        if(parseStringList(ip_port_pair, &ni, 1) != 1) {
            WARN("Net : No valid <IPv4_or_hostname>:<port> pair found");
            return scclInvalidArgument;
        }

        struct addrinfo hints, *p = NULL;
        int rv;
        memset(&hints, 0, sizeof(hints));
        hints.ai_family   = AF_UNSPEC;
        hints.ai_socktype = SOCK_STREAM;

        if((rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
            WARN("Net : error encountered when getting address info : %s", gai_strerror(rv));
            return scclInvalidArgument;
        }

        // use the first
        scclResult_t res = scclSuccess;
        if(p->ai_family == AF_INET) {
            struct sockaddr_in& sin = ua->sin;
            memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in));
            sin.sin_family = AF_INET;      // IPv4
            sin.sin_port   = htons(ni.port); // port
        } else if(p->ai_family == AF_INET6) {
            struct sockaddr_in6& sin6 = ua->sin6;
            memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6));
            sin6.sin6_family   = AF_INET6;        // IPv6
            sin6.sin6_port     = htons(ni.port);  // port
            sin6.sin6_flowinfo = 0;               // needed by IPv6, but possibly obsolete
            sin6.sin6_scope_id = 0;               // should be global scope, set to 0
        } else {
            WARN("Net : unsupported IP family");
            res = scclInvalidArgument;
        }

        // Released on every path; the error path used to leak the list.
        freeaddrinfo(p);
        return res;
    }

    int i, j = -1, len = strlen(ip_port_pair);
    for(i = 1; i < len; i++) {
        if(ip_port_pair[i] == '%')
            j = i;
        if(ip_port_pair[i] == ']')
            break;
    }
    // Require the closing bracket to be followed by ':'. Without this check a
    // string ending at ']' made the port copy start beyond the terminating NUL.
    if(i == len || ip_port_pair[i + 1] != ':') {
        WARN("Net : No valid [IPv6]:port pair found");
        return scclInvalidArgument;
    }
    bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope

    char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
    memset(ip_str, '\0', sizeof(ip_str));
    memset(port_str, '\0', sizeof(port_str));
    memset(if_name, '\0', sizeof(if_name));

    // Component lengths, excluding terminators. Each must fit its buffer with
    // room for the NUL already provided by the memset above.
    const int ipLen   = global_scope ? i - 1 : j - 1;
    const int portLen = len - i - 2;
    const int ifLen   = global_scope ? 0 : i - j - 1;
    if(ipLen >= (int)sizeof(ip_str) || portLen >= (int)sizeof(port_str) || ifLen >= (int)sizeof(if_name)) {
        WARN("Net : [IPv6]:port string component too long");
        return scclInvalidArgument;
    }

    memcpy(ip_str, ip_port_pair + 1, ipLen);
    memcpy(port_str, ip_port_pair + i + 2, portLen);
    int port = atoi(port_str);
    if(!global_scope)
        memcpy(if_name, ip_port_pair + j + 1, ifLen); // If not global scope, we need the intf name

    struct sockaddr_in6& sin6 = ua->sin6;
    sin6.sin6_family = AF_INET6; // IPv6
    // inet_pton leaves the address untouched on failure, so reject bad input.
    if(inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)) != 1) {
        WARN("Net : invalid IPv6 address %s", ip_str);
        return scclInvalidArgument;
    }
    sin6.sin6_port     = htons(port);                                     // port
    sin6.sin6_flowinfo = 0;                                               // needed by IPv6, but possibly obsolete
    sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name);      // 0 if global scope; intf index if link scope
    return scclSuccess;
}
int scclFindInterfaces(char* ifNames, union scclSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs) {
static int shownIfName = 0;
int nIfs = 0;
// Allow user to force the INET socket family selection
int sock_family = envSocketFamily();
// User specified interface
char* env = getenv("SCCL_SOCKET_IFNAME");
if(env && strlen(env) > 1) {
INFO(SCCL_LOG_CODEALL, "SCCL_SOCKET_IFNAME set by environment to %s", env);
// Specified by user : find or fail
if(shownIfName++ == 0)
INFO(SCCL_LOG_CODEALL, "SCCL_SOCKET_IFNAME set to %s", env);
nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
} else {
// Try to automatically pick the right one
// Start with IB
nIfs = findInterfaces("ib", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// else see if we can get some hint from COMM ID
if(nIfs == 0) {
char* commId = getenv("SCCL_COMM_ID");
if(commId && strlen(commId) > 1) {
INFO(SCCL_LOG_CODEALL, "SCCL_COMM_ID set by environment to %s", commId);
// Try to find interface that is in the same subnet as the IP in comm id
union scclSocketAddress idAddr;
scclSocketGetAddrFromString(&idAddr, commId);
nIfs = scclFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);
}
}
// Then look for anything else (but not docker or lo)
if(nIfs == 0)
nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// Finally look for docker, then lo.
if(nIfs == 0)
nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
if(nIfs == 0)
nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
}
return nIfs;
}
/**
 * Bind `sock` to its configured address and put it into listening mode.
 *
 * When the address carries a non-zero port, address/port reuse is enabled
 * before binding. After binding, the actual address (including a
 * kernel-assigned port) is read back into `sock->addr` and the socket state
 * becomes scclSocketStateReady.
 *
 * @param sock Socket with a valid `fd`, `addr` and `salen`.
 * @return scclSuccess, scclInvalidArgument for a NULL socket or fd == -1, or
 *         whatever SYSCHECK returns when a system call fails.
 */
scclResult_t scclSocketListen(struct scclSocket* sock) {
    if(sock == NULL) {
        WARN("scclSocketListen: pass NULL socket");
        return scclInvalidArgument;
    }
    if(sock->fd == -1) {
        WARN("scclSocketListen: file descriptor is -1");
        return scclInvalidArgument;
    }

    if(socketToPort(&sock->addr)) {
        // Port is forced by env. Make sure we get the port.
        // Socket option names are plain numbers, not bit flags, so each option
        // needs its own setsockopt() call; OR-ing them selects a different one.
        int opt = 1;
        SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt");
#if defined(SO_REUSEPORT)
        SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
#endif
    }

    // addr port should be 0 (Any port)
    SYSCHECK(bind(sock->fd, &sock->addr.sa, sock->salen), "bind");

    /* Get the assigned Port */
    socklen_t size = sock->salen;
    SYSCHECK(getsockname(sock->fd, &sock->addr.sa, &size), "getsockname");

    /* Put the socket in listen mode
     * NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn
     */
    SYSCHECK(listen(sock->fd, 16384), "listen");
    sock->state = scclSocketStateReady;
    return scclSuccess;
}
/**
 * Copy the address of a ready socket into `addr`.
 *
 * @return scclInvalidArgument for a NULL socket, scclInternalError when the
 *         socket is not in scclSocketStateReady, scclSuccess otherwise.
 */
scclResult_t scclSocketGetAddr(struct scclSocket* sock, union scclSocketAddress* addr) {
    if(sock == NULL) {
        WARN("scclSocketGetAddr: pass NULL socket");
        return scclInvalidArgument;
    }

    const bool ready = (sock->state == scclSocketStateReady);
    if(!ready) {
        return scclInternalError;
    }

    memcpy(addr, &sock->addr, sizeof(*addr));
    return scclSuccess;
}
/**
 * Attempt one accept() on `sock->acceptFd`.
 *
 * On success the new descriptor and peer address are stored in `sock` and the
 * state becomes scclSocketStateAccepted. EAGAIN / EWOULDBLOCK leave the state
 * untouched and still return scclSuccess; any other error is logged and
 * reported as scclSystemError.
 */
static scclResult_t socketTryAccept(struct scclSocket* sock) {
    socklen_t addrLen = sizeof(union scclSocketAddress);
    sock->fd          = accept(sock->acceptFd, &sock->addr.sa, &addrLen);

    if(sock->fd != -1) {
        sock->state = scclSocketStateAccepted;
        return scclSuccess;
    }

    // No pending connection yet: not an error, the caller retries later.
    if(errno == EAGAIN || errno == EWOULDBLOCK) {
        return scclSuccess;
    }

    WARN("socketTryAccept: Accept failed: %s", strerror(errno));
    return scclSystemError;
}
/**
 * Complete the handshake on a freshly accepted connection.
 *
 * Enables TCP_NODELAY, then reads the peer's magic value and socket type.
 *  - Nothing received yet: return scclSuccess and leave the state unchanged.
 *  - Wrong magic: close the descriptor, go back to scclSocketStateAccepting
 *    and return scclSuccess so another connection can be accepted.
 *  - Wrong type: close the descriptor, set scclSocketStateError and return
 *    scclInternalError.
 *  - Otherwise the socket becomes scclSocketStateReady.
 */
static scclResult_t socketFinalizeAccept(struct scclSocket* sock) {
    const int one = 1;
    SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");

    uint64_t magic;
    int received = 0;
    SCCLCHECK(scclSocketProgress(SCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
    if(received == 0) {
        return scclSuccess;
    }
    SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received));

    if(magic != sock->magic) {
        WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic);
        close(sock->fd);
        sock->fd = -1;
        // Ignore spurious connection and accept again
        sock->state = scclSocketStateAccepting;
        return scclSuccess;
    }

    enum scclSocketType type;
    received = 0;
    SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, &type, sizeof(type), &received));

    if(type != sock->type) {
        WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type);
        sock->state = scclSocketStateError;
        close(sock->fd);
        sock->fd = -1;
        return scclInternalError;
    }

    sock->state = scclSocketStateReady;
    return scclSuccess;
}
/**
 * Issue one connect() attempt and update the socket state from its outcome.
 *
 *  - Success: state becomes scclSocketStateConnected.
 *  - EINPROGRESS: state becomes scclSocketStateConnectPolling.
 *  - ECONNREFUSED / ETIMEDOUT: the matching retry counter is incremented; once
 *    it reaches its limit the state becomes scclSocketStateError and
 *    scclRemoteError is returned, otherwise the call sleeps SLEEP_INT and
 *    returns scclSuccess with the state unchanged so the caller retries.
 *  - Any other error: state becomes scclSocketStateError, scclSystemError.
 */
static scclResult_t socketStartConnect(struct scclSocket* sock) {
    /* blocking/non-blocking connect() is determined by asyncFlag. */
    if(connect(sock->fd, &sock->addr.sa, sock->salen) == 0) {
        sock->state = scclSocketStateConnected;
        return scclSuccess;
    }

    switch(errno) {
        case EINPROGRESS:
            sock->state = scclSocketStateConnectPolling;
            return scclSuccess;

        case ECONNREFUSED:
            if(++sock->refusedRetries == RETRY_REFUSED_TIMES) {
                sock->state = scclSocketStateError;
                WARN("socketStartConnect: exceeded retries (%d)", sock->refusedRetries);
                return scclRemoteError;
            }
            usleep(SLEEP_INT);
            if(sock->refusedRetries % 1000 == 0)
                INFO(SCCL_LOG_CODEALL, "Call to connect returned %s, retrying", strerror(errno));
            return scclSuccess;

        case ETIMEDOUT:
            if(++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) {
                sock->state = scclSocketStateError;
                WARN("socketStartConnect: exceeded timeouts (%d)", sock->timedOutRetries);
                return scclRemoteError;
            }
            usleep(SLEEP_INT);
            return scclSuccess;

        default: {
            char line[SOCKET_NAME_MAXLEN + 1];
            sock->state = scclSocketStateError;
            WARN("socketStartConnect: Connect to %s failed : %s", scclSocketToString(&sock->addr, line), strerror(errno));
            return scclSystemError;
        }
    }
}
/**
 * Poll a non-blocking connect() for completion and update the socket state.
 *
 * Waits up to 1 ms for the descriptor to become writable. If it does, the
 * pending error is read with SO_ERROR:
 *  - 0: state becomes scclSocketStateConnected.
 *  - ECONNREFUSED / ETIMEDOUT: the matching retry counter is incremented; at
 *    its limit the state becomes scclSocketStateError (scclRemoteError),
 *    otherwise the call sleeps SLEEP_INT and the state goes back to
 *    scclSocketStateConnecting.
 *  - EINPROGRESS: nothing changes.
 *  - Anything else: state becomes scclSocketStateError (scclSystemError).
 * A poll timeout or EINTR returns scclSuccess without changing the state.
 */
static scclResult_t socketPollConnect(struct scclSocket* sock) {
    struct pollfd pfd;
    int timeout = 1, ret;
    socklen_t rlen = sizeof(int);

    memset(&pfd, 0, sizeof(struct pollfd));
    pfd.fd     = sock->fd;
    pfd.events = POLLOUT;
    ret        = poll(&pfd, 1, timeout);

    if(ret == 0 || (ret < 0 && errno == EINTR)) {
        return scclSuccess;
    } else if(ret < 0) {
        WARN("socketPollConnect poll() failed with error %s", strerror(errno));
        return scclRemoteError;
    } else {
        EQCHECK(ret == 1 && (pfd.revents & POLLOUT), 0);
    }

    /* check socket status */
    SYSCHECK(getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt");

    if(ret == 0) {
        sock->state = scclSocketStateConnected;
    } else if(ret == ECONNREFUSED) {
        if(++sock->refusedRetries == RETRY_REFUSED_TIMES) {
            sock->state = scclSocketStateError;
            WARN("socketPollConnect: exceeded retries (%d)", sock->refusedRetries);
            return scclRemoteError;
        }
        // The connect error is the SO_ERROR value in `ret`; errno is unrelated
        // here because getsockopt() itself succeeded.
        if(sock->refusedRetries % 1000 == 0)
            INFO(SCCL_LOG_CODEALL, "Call to connect returned %s, retrying", strerror(ret));
        usleep(SLEEP_INT);
        sock->state = scclSocketStateConnecting;
    } else if(ret == ETIMEDOUT) {
        if(++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) {
            sock->state = scclSocketStateError;
            WARN("socketPollConnect: exceeded timeouts (%d)", sock->timedOutRetries);
            return scclRemoteError;
        }
        usleep(SLEEP_INT);
        sock->state = scclSocketStateConnecting;
    } else if(ret != EINPROGRESS) {
        sock->state = scclSocketStateError;
        return scclSystemError;
    }
    return scclSuccess;
}
/**
 * Public wrapper around socketPollConnect that rejects a NULL socket.
 *
 * @return scclInvalidArgument for a NULL socket, otherwise the result of
 *         socketPollConnect.
 */
scclResult_t scclSocketPollConnect(struct scclSocket* sock) {
    if(sock != NULL) {
        SCCLCHECK(socketPollConnect(sock));
        return scclSuccess;
    }

    WARN("scclSocketPollConnect: pass NULL socket");
    return scclInvalidArgument;
}
/**
 * Complete the handshake on the connecting side.
 *
 * Sends the socket's magic value followed by its type. If the first
 * non-blocking attempt sends nothing, the function returns scclSuccess and
 * leaves the state unchanged so it can be retried. Once both values are sent
 * the state becomes scclSocketStateReady.
 */
static scclResult_t socketFinalizeConnect(struct scclSocket* sock) {
    int sent = 0;
    SCCLCHECK(socketProgress(SCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));

    if(sent != 0) {
        // Handshake started: finish the magic, then send the type.
        SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));
        sent = 0;
        SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent));
        sock->state = scclSocketStateReady;
    }
    return scclSuccess;
}
// Advance the socket's connection state machine by one pass.
//
// Each check below runs the handler for one state. The checks are evaluated
// in sequence and a handler may change sock->state, so a single call can fall
// through several stages (e.g. Accepting -> Accepted, or
// Connecting -> ConnectPolling -> Connected). The order of the checks is
// therefore significant. Any handler error is propagated through SCCLCHECK.
static scclResult_t socketProgressState(struct scclSocket* sock) {
// Listening side: try to accept a pending connection.
if(sock->state == scclSocketStateAccepting) {
SCCLCHECK(socketTryAccept(sock));
}
// Listening side: finish the handshake on the accepted connection.
if(sock->state == scclSocketStateAccepted) {
SCCLCHECK(socketFinalizeAccept(sock));
}
// Connecting side: issue the connect() call.
if(sock->state == scclSocketStateConnecting) {
SCCLCHECK(socketStartConnect(sock));
}
// Connecting side: poll an in-progress connect() for completion.
if(sock->state == scclSocketStateConnectPolling) {
SCCLCHECK(socketPollConnect(sock));
}
// Connecting side: finish the handshake on the established connection.
if(sock->state == scclSocketStateConnected) {
SCCLCHECK(socketFinalizeConnect(sock));
}
return scclSuccess;
}
/**
 * Report whether `sock` is ready, progressing its state machine once if not.
 *
 * @param sock    Socket to query; NULL yields `*running = 0` and scclSuccess.
 * @param running Set to 1 when the socket is in scclSocketStateReady, else 0.
 * @return scclRemoteError when the socket is in the error or closed state,
 *         an error from socketProgressState, or scclSuccess.
 */
scclResult_t scclSocketReady(struct scclSocket* sock, int* running) {
    if(sock == NULL) {
        *running = 0;
        return scclSuccess;
    }

    const bool unusable = (sock->state == scclSocketStateError) || (sock->state == scclSocketStateClosed);
    if(unusable) {
        WARN("scclSocketReady: unexpected socket state %d", sock->state);
        return scclRemoteError;
    }

    if(sock->state == scclSocketStateReady) {
        *running = 1;
        return scclSuccess;
    }

    // Not ready yet: report 0 first so an early error return leaves it at 0,
    // then give the state machine one chance to advance and re-evaluate.
    *running = 0;
    SCCLCHECK(socketProgressState(sock));
    *running = (sock->state == scclSocketStateReady) ? 1 : 0;
    return scclSuccess;
}
/**
 * @brief Connects an initialized socket to sock->addr.
 *
 * The socket must be in the Initialized state with a valid file descriptor.
 * TCP_NODELAY is always enabled. When portReuse is non-zero the client side is
 * bound to one of a small pool of local ports (derived from the thread id) so
 * that connections to different remote peers can share a port.
 *
 * For a blocking socket (asyncFlag == 0) the state machine is driven until it
 * leaves the connecting states or the abort flag is raised; for an async
 * socket it is progressed exactly once.
 *
 * @param sock      Socket to connect; must not be NULL.
 * @param portReuse Non-zero to bind the client to a reusable local port.
 * @return scclSuccess while connecting or once Ready; scclInvalidArgument for a
 *         NULL socket or fd == -1; scclRemoteError / scclInternalError for a
 *         wrong initial state; scclSystemError for failed system calls or the
 *         Error state; scclInternalError when aborted or in an unexpected state.
 */
scclResult_t scclSocketConnect(struct scclSocket* sock, int portReuse) {
    char line[SOCKET_NAME_MAXLEN + 1];
    const int one = 1;
    if(sock == NULL) {
        WARN("scclSocketConnect: pass NULL socket");
        return scclInvalidArgument;
    }
    if(sock->fd == -1) {
        WARN("scclSocketConnect: file descriptor is -1");
        return scclInvalidArgument;
    }
    if(sock->state != scclSocketStateInitialized) {
        WARN("scclSocketConnect: wrong socket state %d", sock->state);
        if(sock->state == scclSocketStateError)
            return scclRemoteError;
        return scclInternalError;
    }
    SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");
    if(portReuse) {
        const int family = sock->addr.sa.sa_family;
        if(family != AF_INET && family != AF_INET6) {
            WARN("Net : connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
                 scclSocketToString(&sock->addr, line),
                 family,
                 AF_INET,
                 AF_INET6);
            return scclInternalError;
        }
        // pre-define ports according to tid, to avoid extra lock for race condition
        if(clientPortPool.size() == 0) {
            for(int tid = syscall(SYS_gettid), i = 1; i < 5; i++) {
                clientPortPool.push_back(std::make_pair(60000 + i * 1000 + tid % 1000, std::unordered_set<std::string>()));
            }
        }
        // find a port without conflict (different remote peer) in best effort
        int reused_port = -1;
        std::string remote_peer(scclSocketToString(&sock->addr, line));
        for(auto& port : clientPortPool) {
            if(port.second.find(remote_peer) == port.second.end()) {
                reused_port = port.first;
                port.second.insert(remote_peer);
                break;
            }
        }
        // bind the port in fd for connect system call
        if(reused_port != -1) {
            int opt = 1;
            // Option names are distinct integers, not bit flags: each one needs its own call.
            SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt");
            SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
            // Build a wildcard local address of the same family as the peer so the
            // length handed to bind() always matches the structure actually filled in.
            union scclSocketAddress local;
            memset(&local, 0, sizeof(local));
            int salen;
            if(family == AF_INET) {
                local.sin.sin_family      = AF_INET;
                local.sin.sin_addr.s_addr = htonl(INADDR_ANY);
                local.sin.sin_port        = htons(reused_port);
                salen                     = sizeof(struct sockaddr_in);
            } else {
                local.sin6.sin6_family = AF_INET6;
                local.sin6.sin6_addr   = in6addr_any;
                local.sin6.sin6_port   = htons(reused_port);
                salen                  = sizeof(struct sockaddr_in6);
            }
            SYSCHECK(bind(sock->fd, &local.sa, salen), "bind_client_port");
        }
    }
    sock->state = scclSocketStateConnecting;
    // Blocking sockets loop until the handshake leaves the connecting states or an
    // abort is requested; async sockets (asyncFlag != 0) progress only once.
    do {
        SCCLCHECK(socketProgressState(sock));
    } while(sock->asyncFlag == 0 && (sock->abortFlag == NULL || *sock->abortFlag == 0) &&
            (sock->state == scclSocketStateConnecting || sock->state == scclSocketStateConnectPolling || sock->state == scclSocketStateConnected));
    if(sock->abortFlag && *sock->abortFlag != 0)
        return scclInternalError;
    switch(sock->state) {
        case scclSocketStateConnecting:
        case scclSocketStateConnectPolling:
        case scclSocketStateConnected:
        case scclSocketStateReady: return scclSuccess;
        case scclSocketStateError: return scclSystemError;
        default: WARN("scclSocketConnect: wrong socket state %d", sock->state); return scclInternalError;
    }
}
/**
 * @brief Accepts an incoming connection from a listening socket.
 *
 * On the first call for a given sock (acceptFd == -1) the listening socket is
 * copied into sock, its fd is remembered in sock->acceptFd and the state is set
 * to Accepting. The state machine is then progressed; a blocking socket keeps
 * looping while it is still Accepting/Accepted and no abort was requested.
 *
 * @param sock       Socket that will hold the accepted connection.
 * @param listenSock Listening socket; must be in the Ready state.
 * @return scclSuccess while accepting or once Ready; scclInvalidArgument for
 *         NULL pointers; scclSystemError / scclInternalError for bad states;
 *         scclInternalError when aborted; or the failure from
 *         socketProgressState().
 */
scclResult_t scclSocketAccept(struct scclSocket* sock, struct scclSocket* listenSock) {
    scclResult_t ret = scclSuccess;
    if(sock == NULL || listenSock == NULL) {
        WARN("scclSocketAccept: pass NULL socket");
        return scclInvalidArgument;
    }
    if(listenSock->state != scclSocketStateReady) {
        WARN("scclSocketAccept: wrong socket state %d", listenSock->state);
        return (listenSock->state == scclSocketStateError) ? scclSystemError : scclInternalError;
    }
    if(sock->acceptFd == -1) {
        // First call: inherit the listener's settings and start accepting.
        memcpy(sock, listenSock, sizeof(struct scclSocket));
        sock->acceptFd = listenSock->fd;
        sock->state    = scclSocketStateAccepting;
    }
    bool stillAccepting = false;
    do {
        SCCLCHECKGOTO(socketProgressState(sock), ret, exit);
        stillAccepting = sock->asyncFlag == 0 && (sock->abortFlag == NULL || *sock->abortFlag == 0) &&
                         (sock->state == scclSocketStateAccepting || sock->state == scclSocketStateAccepted);
    } while(stillAccepting);
    if(sock->abortFlag && *sock->abortFlag != 0)
        return scclInternalError;
    if(sock->state == scclSocketStateAccepting || sock->state == scclSocketStateAccepted || sock->state == scclSocketStateReady) {
        ret = scclSuccess;
    } else if(sock->state == scclSocketStateError) {
        ret = scclSystemError;
    } else {
        WARN("scclSocketAccept: wrong socket state %d", sock->state);
        ret = scclInternalError;
    }
exit:
    return ret;
}
// Initializes a scclSocket structure.
//
// Resets the retry counters, stores the abort/async flags, magic and type, and
// sets the state to Initialized with fd and acceptFd set to -1. When `addr` is
// given it is copied into the socket, its family is validated (AF_INET or
// AF_INET6), salen is set accordingly and a SOCK_STREAM socket is created;
// otherwise sock->addr is zeroed and no descriptor is opened. The descriptor is
// switched to non-blocking mode when either asyncFlag or abortFlag is set.
//
// Returns scclSuccess (also when `sock` is NULL, in which case nothing is done),
// scclInternalError for an unsupported address family, or scclSystemError when
// socket()/fcntl() fails.
// NOTE(review): on the fcntl failure path sock->fd is left open; the caller
// presumably releases it with scclSocketClose - confirm.
scclResult_t
scclSocketInit(struct scclSocket* sock, union scclSocketAddress* addr, uint64_t magic, enum scclSocketType type, volatile uint32_t* abortFlag, int asyncFlag) {
scclResult_t ret = scclSuccess;
if(sock == NULL)
goto exit;
sock->timedOutRetries = 0;
sock->refusedRetries = 0;
sock->abortFlag = abortFlag;
sock->asyncFlag = asyncFlag;
sock->state = scclSocketStateInitialized;
sock->magic = magic;
sock->type = type;
sock->fd = -1;
sock->acceptFd = -1;
if(addr) {
/* IPv4/IPv6 support */
int family;
memcpy(&sock->addr, addr, sizeof(union scclSocketAddress));
family = sock->addr.sa.sa_family;
if(family != AF_INET && family != AF_INET6) {
char line[SOCKET_NAME_MAXLEN + 1];
WARN("scclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
scclSocketToString(&sock->addr, line),
family,
AF_INET,
AF_INET6);
ret = scclInternalError;
goto fail;
}
sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
/* Connect to a hostname / port */
sock->fd = socket(family, SOCK_STREAM, 0);
if(sock->fd == -1) {
WARN("scclSocketInit: Socket creation failed : %s", strerror(errno));
ret = scclSystemError;
goto fail;
}
} else {
memset(&sock->addr, 0, sizeof(union scclSocketAddress));
}
/* Set socket as non-blocking if async or if we need to be able to abort */
if((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) {
int flags;
EQCHECKGOTO(flags = fcntl(sock->fd, F_GETFL), -1, ret, fail);
SYSCHECKGOTO(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), ret, fail);
}
exit:
return ret;
fail:
// No extra cleanup is performed on failure; control simply joins the common exit.
goto exit;
}
/**
 * @brief Public wrapper around socketProgress(): one attempt to send or
 *        receive on the socket.
 *
 * @param op     SCCL_SOCKET_SEND or SCCL_SOCKET_RECV.
 * @param sock   Socket to operate on; must not be NULL.
 * @param ptr    Data buffer.
 * @param size   Total number of bytes to transfer.
 * @param offset In/out count of bytes already transferred.
 * @return scclInvalidArgument for a NULL socket, otherwise scclSuccess unless
 *         SCCLCHECK propagates a failure.
 */
scclResult_t scclSocketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
    if(sock != NULL) {
        SCCLCHECK(socketProgress(op, sock, ptr, size, offset));
        return scclSuccess;
    }
    WARN("scclSocketProgress: pass NULL socket");
    return scclInvalidArgument;
}
/**
 * @brief Public wrapper around socketWait() for a send or receive operation.
 *
 * @param op     SCCL_SOCKET_SEND or SCCL_SOCKET_RECV.
 * @param sock   Socket to operate on; must not be NULL.
 * @param ptr    Data buffer.
 * @param size   Total number of bytes to transfer.
 * @param offset In/out count of bytes already transferred.
 * @return scclInvalidArgument for a NULL socket, otherwise scclSuccess unless
 *         SCCLCHECK propagates a failure.
 */
scclResult_t scclSocketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
    if(sock != NULL) {
        SCCLCHECK(socketWait(op, sock, ptr, size, offset));
        return scclSuccess;
    }
    WARN("scclSocketWait: pass NULL socket");
    return scclInvalidArgument;
}
/**
 * @brief Sends `size` bytes from `ptr` on a Ready socket via socketWait().
 *
 * @param sock Socket to send on; must not be NULL and must be Ready.
 * @param ptr  Buffer holding the bytes to send.
 * @param size Number of bytes to send.
 * @return scclInvalidArgument for a NULL socket, scclInternalError when the
 *         socket is not Ready, otherwise scclSuccess unless SCCLCHECK
 *         propagates a failure.
 */
scclResult_t scclSocketSend(struct scclSocket* sock, void* ptr, int size) {
    if(sock == NULL) {
        WARN("scclSocketSend: pass NULL socket");
        return scclInvalidArgument;
    }
    const bool ready = (sock->state == scclSocketStateReady);
    if(!ready) {
        WARN("scclSocketSend: socket state (%d) is not ready", sock->state);
        return scclInternalError;
    }
    int sentBytes = 0; // the transfer always starts at the beginning of the buffer
    SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, ptr, size, &sentBytes));
    return scclSuccess;
}
/**
 * @brief Receives `size` bytes into `ptr` from a Ready socket via socketWait().
 *
 * @param sock Socket to receive from; must not be NULL and must be Ready.
 * @param ptr  Destination buffer.
 * @param size Number of bytes to receive.
 * @return scclInvalidArgument for a NULL socket, scclInternalError when the
 *         socket is not Ready, otherwise scclSuccess unless SCCLCHECK
 *         propagates a failure.
 */
scclResult_t scclSocketRecv(struct scclSocket* sock, void* ptr, int size) {
    if(sock == NULL) {
        WARN("scclSocketRecv: pass NULL socket");
        return scclInvalidArgument;
    }
    const bool ready = (sock->state == scclSocketStateReady);
    if(!ready) {
        WARN("scclSocketRecv: socket state (%d) is not ready", sock->state);
        return scclInternalError;
    }
    int receivedBytes = 0; // the transfer always starts at the beginning of the buffer
    SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, ptr, size, &receivedBytes));
    return scclSuccess;
}
/**
 * @brief Receives `size` bytes or detects that the connection was closed.
 *
 * In blocking mode the call loops until all bytes arrived or the peer closed.
 * In non-blocking mode a single attempt is made first: if nothing was received
 * the call returns scclInProgress; if any byte arrived it then loops for the
 * remainder exactly like the blocking mode.
 *
 * @param sock     Socket to receive from; must not be NULL.
 * @param ptr      Destination buffer.
 * @param size     Number of bytes expected.
 * @param closed   Output: set to non-zero when the connection was closed.
 * @param blocking Whether to wait for the first byte.
 * @return scclInvalidArgument for a NULL socket, scclInProgress when a
 *         non-blocking attempt received nothing, otherwise scclSuccess unless
 *         SCCLCHECK propagates a failure.
 */
scclResult_t scclSocketTryRecv(struct scclSocket* sock, void* ptr, int size, int* closed, bool blocking) {
    if(sock == NULL) {
        WARN("scclSocketTryRecv: pass NULL socket");
        return scclInvalidArgument;
    }
    int received = 0;
    *closed      = 0;
    if(!blocking) {
        // Single probe; only commit to waiting once some bytes have arrived.
        SCCLCHECK(socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &received, 0, closed));
        if(*closed)
            return scclSuccess;
        if(received <= 0)
            return scclInProgress;
    }
    // Keep receiving until the full message arrived or the connection closed.
    while(received < size) {
        SCCLCHECK(socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &received, 0, closed));
        if(*closed)
            return scclSuccess;
    }
    return scclSuccess;
}
/**
 * @brief Shuts down and closes a socket, marking it Closed.
 *
 * A NULL socket is ignored. The file descriptor is only touched when it is
 * non-negative; afterwards the state is Closed and fd is reset to -1.
 *
 * @param sock Socket to close (may be NULL).
 * @return Always scclSuccess.
 */
scclResult_t scclSocketClose(struct scclSocket* sock) {
    if(sock == NULL)
        return scclSuccess;
    if(sock->fd >= 0) {
        /* shutdown() is needed to send FIN packet to proxy thread; shutdown() is not affected
         * by refcount of fd, but close() is. close() won't close a fd and send FIN packet if
         * the fd is duplicated (e.g. fork()). So shutdown() guarantees the correct and graceful
         * connection close here. */
        shutdown(sock->fd, SHUT_RDWR);
        close(sock->fd);
    }
    sock->fd    = -1;
    sock->state = scclSocketStateClosed;
    return scclSuccess;
}
/**
 * @brief Reads the file descriptor stored in a socket.
 *
 * @param sock Socket to query; must not be NULL.
 * @param fd   Optional output; when non-NULL it receives sock->fd.
 * @return scclInvalidArgument for a NULL socket, scclSuccess otherwise.
 */
scclResult_t scclSocketGetFd(struct scclSocket* sock, int* fd) {
    if(sock != NULL) {
        if(fd != NULL)
            *fd = sock->fd;
        return scclSuccess;
    }
    WARN("scclSocketGetFd: pass NULL socket");
    return scclInvalidArgument;
}
/**
 * @brief Stores a file descriptor in a socket.
 *
 * Only sock->fd is overwritten; the socket state and any previously stored
 * descriptor are left untouched (the old descriptor is not closed here).
 *
 * @param fd   File descriptor to store.
 * @param sock Socket to update; must not be NULL.
 * @return scclInvalidArgument for a NULL socket, scclSuccess otherwise.
 */
scclResult_t scclSocketSetFd(int fd, struct scclSocket* sock) {
    if(sock == NULL) {
        // Name this function in the warning (it used to say scclSocketGetFd).
        WARN("scclSocketSetFd: pass NULL socket");
        return scclInvalidArgument;
    }
    sock->fd = fd;
    return scclSuccess;
}
#pragma once
#include "debug.h"
#include "check.h"
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <fcntl.h>
#include <poll.h>
using namespace sccl;
// One entry of a parsed interface list: an interface name (or name prefix)
// plus an optional port, as produced by parseStringList().
struct netIf {
char prefix[64]; // NUL-terminated interface name / prefix
int port; // port given after ':' in the list, or -1 when none was specified
};
// Per-thread switch read by SCCLCHECKGOTO: when non-zero, the failure trace
// emitted by that macro is suppressed. Being `static`, it has internal linkage,
// so every translation unit that includes this file gets its own copy.
static thread_local int scclDebugNoWarn = 0;
// SYSCHECK(call, name): runs a system call with retry (see SYSCHECKSYNC) and
// makes the enclosing function return scclSystemError if it finally yields -1.
// `name` must be a string literal; it is concatenated into the log message.
#define SYSCHECK(call, name) \
do { \
int retval; \
SYSCHECKVAL(call, name, retval); \
} while(false)
// SYSCHECKVAL(call, name, retval): like SYSCHECK but stores the call's result
// in the caller-provided variable `retval`.
#define SYSCHECKVAL(call, name, retval) \
do { \
SYSCHECKSYNC(call, name, retval); \
if(retval == -1) { \
WARN("Call to " name " failed : %s", strerror(errno)); \
return scclSystemError; \
} \
} while(false)
// SYSCHECKSYNC(call, name, retval): evaluates `call` repeatedly for as long as
// it returns -1 with errno EINTR, EWOULDBLOCK or EAGAIN, logging each retry.
// Any other outcome ends the loop with the result left in `retval`.
#define SYSCHECKSYNC(call, name, retval) \
do { \
retval = call; \
if(retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
INFO(SCCL_LOG_CODEALL, "Call to " name " returned %s, retrying", strerror(errno)); \
} else { \
break; \
} \
} while(true)
// NOTE(review): every macro in this group ends in `while(0);` - the expansion
// already carries a semicolon, so using one as the sole statement of an
// un-braced `if` that has an `else` will not compile.

// EQCHECK(statement, value): if `statement` evaluates to `value`, logs the
// location and errno text and returns scclSystemError from the enclosing function.
#define EQCHECK(statement, value) \
do { \
if((statement) == value) { \
/* Print the back trace*/ \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, scclSystemError, strerror(errno)); \
return scclSystemError; \
} \
} while(0);
// NEQCHECKGOTO(statement, value, RES, label): if `statement` differs from
// `value`, sets RES to scclSystemError, logs, and jumps to `label`.
#define NEQCHECKGOTO(statement, value, RES, label) \
do { \
if((statement) != value) { \
/* Print the back trace*/ \
RES = scclSystemError; \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
goto label; \
} \
} while(0);
// SYSCHECKGOTO(statement, RES, label): if `statement` evaluates to -1, sets RES
// to scclSystemError, logs, and jumps to `label`. Unlike SYSCHECK there is no retry.
#define SYSCHECKGOTO(statement, RES, label) \
do { \
if((statement) == -1) { \
/* Print the back trace*/ \
RES = scclSystemError; \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
goto label; \
} \
} while(0);
// SCCLCHECKGOTO(call, RES, label): stores the result of `call` in RES. Results
// other than scclSuccess and scclInProgress jump to `label` (with a trace unless
// scclDebugNoWarn is set); otherwise a "check pass" line is logged.
#define SCCLCHECKGOTO(call, RES, label) \
do { \
RES = call; \
if(RES != scclSuccess && RES != scclInProgress) { \
/* Print the back trace*/ \
if(scclDebugNoWarn == 0) \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d", __FILE__, __LINE__, RES); \
goto label; \
} \
INFO(SCCL_LOG_CODEALL, "check pass %s:%d -> %d", __FILE__, __LINE__, RES); \
} while(0);
// EQCHECKGOTO(statement, value, RES, label): if `statement` evaluates to
// `value`, sets RES to scclSystemError, logs, and jumps to `label`.
#define EQCHECKGOTO(statement, value, RES, label) \
do { \
if((statement) == value) { \
/* Print the back trace*/ \
RES = scclSystemError; \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
goto label; \
} \
} while(0);
/**
 * @brief Parses a comma-separated list of "name" or "name:port" entries.
 *
 * Each non-empty name becomes one netIf entry; the port is the integer that
 * follows ':' (parsed with atoi) or -1 when no port was given. Entries with an
 * empty name are skipped. Names longer than the prefix buffer are truncated so
 * the buffer can never be overrun.
 *
 * @param string  NUL-terminated list to parse; NULL yields 0 entries.
 * @param ifList  Output array with room for at least maxList entries.
 * @param maxList Capacity of ifList; values <= 0 yield 0 entries.
 * @return Number of entries written to ifList (at most maxList).
 */
static int parseStringList(const char* string, struct netIf* ifList, int maxList) {
    // Nothing to parse, or nowhere to store results: never touch ifList[0].
    if(!string || !ifList || maxList <= 0)
        return 0;
    // Leave room for the terminating NUL in the fixed-size prefix buffer.
    const int maxPrefixLen = (int)sizeof(ifList[0].prefix) - 1;
    const char* ptr        = string;
    int ifNum              = 0;
    int ifC                = 0;
    char c;
    do {
        c = *ptr;
        if(c == ':') {
            if(ifC > 0) {
                ifList[ifNum].prefix[ifC] = '\0';
                ifList[ifNum].port        = atoi(ptr + 1);
                ifNum++;
                ifC = 0;
            }
            // Skip the port digits up to the next separator or the end of the string.
            while(c != ',' && c != '\0')
                c = *(++ptr);
        } else if(c == ',' || c == '\0') {
            if(ifC > 0) {
                ifList[ifNum].prefix[ifC] = '\0';
                ifList[ifNum].port        = -1;
                ifNum++;
                ifC = 0;
            }
        } else if(ifC < maxPrefixLen) {
            ifList[ifNum].prefix[ifC] = c;
            ifC++;
        }
        // else: the name exceeds the buffer; extra characters are dropped.
        ptr++;
    } while(ifNum < maxList && c);
    return ifNum;
}
/**
 * @brief Compares an interface name against a reference name or prefix.
 *
 * @param string     Interface name to test.
 * @param ref        Reference name (exact mode) or prefix (prefix mode).
 * @param matchExact When true the whole of `string`, including its terminating
 *                   NUL, must equal `ref`; otherwise only the first
 *                   strlen(ref) characters are compared.
 * @return true when the names match under the selected mode.
 */
static bool matchIf(const char* string, const char* ref, bool matchExact) {
    int compareLen;
    if(matchExact) {
        // Include the terminator so that "eth0" does not match "eth0:1" style prefixes.
        compareLen = strlen(string) + 1;
    } else {
        compareLen = strlen(ref);
    }
    const bool same = (strncmp(string, ref, compareLen) == 0);
    return same;
}
/**
 * @brief Checks whether two ports match, treating -1 as a wildcard.
 *
 * @param port1 First port, or -1 for "any".
 * @param port2 Second port, or -1 for "any".
 * @return true when either port is -1 or both are equal.
 */
static bool matchPort(const int port1, const int port2) {
    const bool eitherIsWildcard = (port1 == -1) || (port2 == -1);
    return eitherIsWildcard || (port1 == port2);
}
static bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact) {
// Make an exception for the case where no user list is defined
if(listSize == 0)
return true;
for(int i = 0; i < listSize; i++) {
if(matchIf(string, ifList[i].prefix, matchExact) && matchPort(port, ifList[i].port)) {
return true;
}
}
return false;
}
// Limits and tuning constants for the socket layer.
#define MAX_IFS 16
#define MAX_IF_NAME_SIZE 16
#define SLEEP_INT 1000 // connection retry sleep interval in usec
// NOTE(review): 2e4 is a floating-point literal, so comparisons against it are done in double.
#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec)
#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s)
// Buffer size (excluding the terminating NUL) used for printable socket addresses.
#define SOCKET_NAME_MAXLEN (NI_MAXHOST + NI_MAXSERV)
// Default magic value passed to scclSocketInit().
#define SCCL_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL
// Socket address storage that can be viewed as a generic sockaddr, an IPv4
// address or an IPv6 address; sa.sa_family tells which view is valid.
union scclSocketAddress {
struct sockaddr sa; // generic view, used to read the address family
struct sockaddr_in sin; // IPv4 view (AF_INET)
struct sockaddr_in6 sin6; // IPv6 view (AF_INET6)
};
// Connection state of a scclSocket. The accept path goes through
// Accepting -> Accepted -> Ready and the connect path through
// Connecting -> ConnectPolling -> Connected -> Ready.
enum scclSocketState {
scclSocketStateNone = 0,
scclSocketStateInitialized = 1, // set by scclSocketInit; required by scclSocketConnect
scclSocketStateAccepting = 2, // waiting to accept an incoming connection
scclSocketStateAccepted = 3, // connection accepted, handshake not yet finalized
scclSocketStateConnecting = 4, // connect about to be (re)started
scclSocketStateConnectPolling = 5, // connect started, completion being polled
scclSocketStateConnected = 6, // connected, handshake (magic + type) not yet sent
scclSocketStateReady = 7, // usable for send/recv
scclSocketStateClosed = 8, // set by scclSocketClose
scclSocketStateError = 9,
scclSocketStateNum = 10 // number of states
};
// Identifies which subsystem a socket belongs to; the value is stored in
// scclSocket::type and sent to the peer when a connection is finalized.
enum scclSocketType {
scclSocketTypeUnknown = 0, // default used by scclSocketInit
scclSocketTypeBootstrap = 1,
scclSocketTypeProxy = 2,
scclSocketTypeNetSocket = 3,
scclSocketTypeNetIb = 4
};
// State of one TCP socket managed by the scclSocket* API.
struct scclSocket {
int fd; // connection file descriptor, -1 when none is open
int acceptFd; // listening descriptor recorded by scclSocketAccept, -1 until then
int timedOutRetries; // count of timed-out connect attempts (reset by scclSocketInit)
int refusedRetries; // count of refused connect attempts (reset by scclSocketInit)
union scclSocketAddress addr; // address associated with the socket
volatile uint32_t* abortFlag; // optional; a non-zero value makes connect/accept loops stop
int asyncFlag; // non-zero: connect/accept progress once instead of looping
enum scclSocketState state; // current connection state
int salen; // size of the sockaddr structure matching addr's family
uint64_t magic; // value sent first when finalizing a connection
enum scclSocketType type; // value sent after the magic when finalizing a connection
};
// Formats `addr` into `buf` and returns a printable string. Callers in this
// code base pass a buffer of SOCKET_NAME_MAXLEN + 1 bytes.
const char* scclSocketToString(union scclSocketAddress* addr, char* buf, const int numericHostForm = 1);
// Fills `ua` from an "ip:port"-style string (exact accepted formats are defined by the implementation).
scclResult_t scclSocketGetAddrFromString(union scclSocketAddress* ua, const char* ip_port_pair);
// Interface discovery helpers; both return an int (presumably the number of interfaces found - confirm in the implementation).
int scclFindInterfaceMatchSubnet(char* ifNames, union scclSocketAddress* localAddrs, union scclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs);
int scclFindInterfaces(char* ifNames, union scclSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs);
// Initialize a socket. When `addr` is non-NULL a descriptor is created for its
// address family; with the default NULL address no descriptor is opened.
scclResult_t scclSocketInit(struct scclSocket* sock,
union scclSocketAddress* addr = NULL,
uint64_t magic = SCCL_SOCKET_MAGIC,
enum scclSocketType type = scclSocketTypeUnknown,
volatile uint32_t* abortFlag = NULL,
int asyncFlag = 0);
// Create a listening socket. sock->addr can be pre-filled with IP & port info. sock->fd is set after a successful call
scclResult_t scclSocketListen(struct scclSocket* sock);
// Retrieves the address associated with `sock` into `addr`.
scclResult_t scclSocketGetAddr(struct scclSocket* sock, union scclSocketAddress* addr);
// Connect to sock->addr. The socket must already be in the Initialized state
// with a valid descriptor (i.e. initialized with an address). A non-zero
// portReuse binds the client side to a reusable local port before connecting.
scclResult_t scclSocketConnect(struct scclSocket* sock, int portReuse = 0);
// Return socket connection state: *running is 1 once the socket is Ready, 0
// otherwise; a socket that is not Ready yet is progressed once by the call.
scclResult_t scclSocketReady(struct scclSocket* sock, int* running);
// Accept an incoming connection from listenSock->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr.
scclResult_t scclSocketAccept(struct scclSocket* sock, struct scclSocket* ulistenSock);
// Read / overwrite the descriptor stored in the socket.
scclResult_t scclSocketGetFd(struct scclSocket* sock, int* fd);
scclResult_t scclSocketSetFd(int fd, struct scclSocket* sock);
// Operation selectors for the progress/wait functions below.
#define SCCL_SOCKET_SEND 0
#define SCCL_SOCKET_RECV 1
// Transfer helpers: `offset` is the in/out count of bytes already transferred out of `size`.
scclResult_t scclSocketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset);
scclResult_t scclSocketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset);
// Send / receive exactly `size` bytes on a Ready socket.
scclResult_t scclSocketSend(struct scclSocket* sock, void* ptr, int size);
scclResult_t scclSocketRecv(struct scclSocket* sock, void* ptr, int size);
// Receive `size` bytes or detect a closed connection (*closed set to non-zero).
// With blocking == false the call returns scclInProgress when nothing has arrived yet.
scclResult_t scclSocketTryRecv(struct scclSocket* sock, void* ptr, int size, int* closed, bool blocking);
// Shut down and close the descriptor (if any) and mark the socket Closed. NULL is accepted.
scclResult_t scclSocketClose(struct scclSocket* sock);
#include "socket.h"
#include "debug.h"
#include "check.h"
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <fcntl.h>
#include <poll.h>
using namespace sccl;
// Fixed capacities of the arrays inside scclNetSocketComm / scclNetSocketRequest.
#define MAX_REQUESTS 8 // requests per comm
#define MAX_THREADS 16 // helper threads per comm
#define MAX_SOCKETS 64 // data sockets per comm, and sub-tasks per request
// One chunk of a request, progressed on a single data socket by a helper thread.
struct scclNetSocketTask {
int op; // SCCL_SOCKET_SEND or SCCL_SOCKET_RECV
void* data; // start of this chunk's buffer
int size; // chunk size in bytes
struct scclSocket* sock; // data socket the chunk is transferred on
int offset; // bytes transferred so far; the task is complete when offset == size
int used; // 0 = slot free, 1 = slot holds an active task
scclResult_t result; // last result returned by scclSocketProgress for this task
};
// Per-thread circular array of task slots.
struct scclNetSocketTaskQueue {
int next; // index of the slot that will be handed out for the next task
int len; // number of slots in `tasks`
struct scclNetSocketTask* tasks; // slot array; NULL until the queue is first used
};
// A user-visible send/receive request, tested with scclNetSocketTest().
struct scclNetSocketRequest {
int op; // SCCL_SOCKET_SEND or SCCL_SOCKET_RECV
void* data; // user buffer
int size; // buffer size; replaced by the exchanged message size once known
struct scclSocket* ctrlSock; // control socket used for the size exchange (and data when there are no sub-tasks)
int offset; // bytes transferred on ctrlSock when the main thread moves the data
int used; // 1 = size exchange pending, 2 = size exchanged / data in flight, 0 = finished
struct scclNetSocketComm* comm; // owning communicator
struct scclNetSocketTask* tasks[MAX_SOCKETS]; // sub-tasks the request was split into
int nSubs; // number of valid entries in `tasks`
};
// Everything one helper thread (persistentSocketThread) works with.
struct scclNetSocketThreadResources {
struct scclNetSocketTaskQueue threadTaskQueue; // tasks assigned to this thread
int stop; // non-zero asks the thread to exit
struct scclNetSocketComm* comm; // owning communicator
pthread_mutex_t threadLock; // protects the wait on threadCond / updates of the queue's `next`
pthread_cond_t threadCond; // signalled when a new task is queued
};
// A socket-based communicator: one control socket plus a set of data sockets
// that are serviced by helper threads.
struct scclNetSocketComm {
struct scclSocket ctrlSock; // control socket
struct scclSocket socks[MAX_SOCKETS]; // data sockets
int dev; // device index (used when naming helper threads)
int hipDev; // HIP device index (used when naming helper threads)
int nSocks; // number of data sockets in use
int nThreads; // number of helper threads in use
int nextSock; // round-robin index of the data socket for the next task
struct scclNetSocketRequest requests[MAX_REQUESTS]; // request slots
pthread_t helperThread[MAX_THREADS]; // helper thread handles
struct scclNetSocketThreadResources threadResources[MAX_THREADS]; // per-thread state
};
// Integer division of x by y, rounded up (for positive operands).
#define DIVUP(x, y) (((x) + (y) - 1) / (y))
// Lower bound, in bytes, for the size of a sub-task when a request is split across sockets.
#define MIN_CHUNKSIZE (64 * 1024)
/**
 * @brief Allocates a zero-initialized array of `nelem` objects of type T.
 *
 * Uses calloc() so that the element-count multiplication is checked for
 * overflow by the C library and the memory is zeroed in one step.
 *
 * @tparam T       Element type.
 * @param ptr      Output: receives the allocated array on success; left
 *                 untouched on failure.
 * @param nelem    Number of elements to allocate.
 * @param filefunc Call-site file name supplied by the scclCalloc macro (currently unused).
 * @param line     Call-site line number supplied by the scclCalloc macro (currently unused).
 * @return scclSuccess, or scclSystemError when the allocation fails.
 */
template <typename T>
scclResult_t scclCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
    (void)filefunc;
    (void)line;
    void* p = calloc(nelem, sizeof(T));
    if(p == NULL) {
        // Print count and element size separately: their product may not fit in size_t.
        WARN("Failed to malloc %zu elements of %zu bytes", nelem, sizeof(T));
        return scclSystemError;
    }
    *ptr = (T*)p;
    return scclSuccess;
}
// Convenience wrapper that appends the call site's file and line.
#define scclCalloc(...) scclCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
/**
 * @brief Gives a thread a printf-style formatted name.
 *
 * The formatted name is truncated to fit a 16-byte buffer (including the
 * terminating NUL). The function does nothing unless _GNU_SOURCE is defined.
 *
 * @param thread Thread to rename.
 * @param fmt    printf-style format string followed by its arguments.
 */
void scclSetThreadName(pthread_t thread, const char* fmt, ...) {
#ifdef _GNU_SOURCE
    char nameBuf[16];
    va_list args;
    va_start(args, fmt);
    vsnprintf(nameBuf, sizeof(nameBuf), fmt, args);
    va_end(args);
    pthread_setname_np(thread, nameBuf);
#endif
}
// Helper-thread entry point: repeatedly progresses the tasks in this thread's
// queue until asked to stop.
//
// args_ points to the thread's scclNetSocketThreadResources. The queue is
// walked in groups of nSocksPerThread slots; each group is retried until every
// active task in it has transferred all of its bytes. If a pass over the whole
// queue found no work, the thread sleeps on the condition variable until a new
// task is queued (queue->next changes) or `stop` is set.
//
// Returns NULL when `stop` is set or when a task fails to progress.
// NOTE(review): if comm->nSocks < comm->nThreads, nSocksPerThread is 0 and the
// outer for loop never advances - presumably callers guarantee nSocks is a
// positive multiple of nThreads; confirm.
void* persistentSocketThread(void* args_) {
struct scclNetSocketThreadResources* resource = (struct scclNetSocketThreadResources*)args_;
struct scclNetSocketComm* comm = resource->comm;
struct scclNetSocketTaskQueue* myQueue = &resource->threadTaskQueue;
int nSocksPerThread = comm->nSocks / comm->nThreads;
while(1) {
int idle = 1;
int mark = myQueue->next; // mark newest task seen
for(int i = 0; i < myQueue->len; i += nSocksPerThread) {
int repeat;
do {
repeat = 0;
for(int j = 0; j < nSocksPerThread; j++) {
struct scclNetSocketTask* r = myQueue->tasks + i + j;
// Only slots that hold an unfinished task are progressed.
if(r != NULL && r->used == 1 && r->offset < r->size) {
r->result = scclSocketProgress(r->op, r->sock, r->data, r->size, &r->offset);
if(r->result != scclSuccess) {
// The failing result stays in r->result, where scclNetSocketTest reads it.
WARN("NET/Socket : socket progress error");
return NULL;
}
idle = 0;
if(r->offset < r->size)
repeat = 1;
}
}
} while(repeat);
}
if(idle) {
pthread_mutex_lock(&resource->threadLock);
while(mark == myQueue->next && resource->stop == 0) { // no new tasks, wait
pthread_cond_wait(&resource->threadCond, &resource->threadLock);
}
pthread_mutex_unlock(&resource->threadLock);
}
if(resource->stop)
return NULL;
}
}
// Allocates a task slot for one chunk of data and hands it to a helper thread.
//
// The helper thread is chosen as comm->nextSock % comm->nThreads. On first use
// of that thread its task queue, mutex and condition variable are created and
// the thread itself is started. The next slot of the queue is then filled in,
// bound to data socket comm->nextSock (which advances round-robin), published
// by advancing queue->next under the lock, and the thread is signalled.
//
// comm: communicator owning the sockets and threads.
// op:   SCCL_SOCKET_SEND or SCCL_SOCKET_RECV.
// data/size: buffer and length of the chunk.
// req:  output, receives the task that was set up.
//
// Returns scclSuccess, the scclCalloc failure, or scclInternalError when the
// next slot of the queue is still in use.
// NOTE(review): the return value of pthread_create is not checked.
scclResult_t scclNetSocketGetTask(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketTask** req) {
int tid = comm->nextSock % comm->nThreads;
struct scclNetSocketThreadResources* res = comm->threadResources + tid;
struct scclNetSocketTaskQueue* queue = &res->threadTaskQueue;
// create helper threads and prepare per-thread task queue
if(queue->tasks == NULL) {
// each request can be divided up to nSocks tasks, and
// these tasks are distributed to nThreads threads,
// we need to make sure each thread queue has enough slots for MAX_REQUESTS
queue->len = MAX_REQUESTS * DIVUP(comm->nSocks, comm->nThreads);
SCCLCHECK(scclCalloc(&queue->tasks, queue->len));
queue->next = 0;
res->comm = comm;
pthread_mutex_init(&res->threadLock, NULL);
pthread_cond_init(&res->threadCond, NULL);
pthread_create(comm->helperThread + tid, NULL, persistentSocketThread, res);
scclSetThreadName(comm->helperThread[tid], "NCCL Sock%c%1u%2u%2u", op == SCCL_SOCKET_SEND ? 'S' : 'R', comm->dev, tid, comm->hipDev);
}
struct scclNetSocketTask* r = queue->tasks + queue->next;
if(r->used == 0) {
r->op = op;
r->data = data;
r->size = size;
r->sock = comm->socks + comm->nextSock;
r->offset = 0;
r->result = scclSuccess;
comm->nextSock = (comm->nextSock + 1) % comm->nSocks;
// Mark the slot active only after all other fields are filled in.
r->used = 1;
*req = r;
pthread_mutex_lock(&res->threadLock);
queue->next = (queue->next + 1) % queue->len;
pthread_cond_signal(&res->threadCond);
pthread_mutex_unlock(&res->threadLock);
return scclSuccess;
}
WARN("NET/Socket : unable to allocate subtasks");
return scclInternalError;
}
/**
 * @brief Tests and advances a socket communication request.
 *
 * What the call does depends on request->used:
 * - used == 1 (size exchange pending): tries to send/receive the message size
 *   on the control socket. If nothing could be transferred yet the call returns
 *   with *done == 0. Otherwise the size is completed, checked against the user
 *   buffer for receives, stored in the request, the request moves to used == 2
 *   and, when the comm has data sockets, the payload is split into sub-tasks
 *   that are handed to the helper threads.
 * - used == 2 (size exchanged): with sub-tasks, checks whether all of them have
 *   finished; without sub-tasks, progresses the payload on the control socket
 *   from the calling thread. When everything has been transferred, *done is set
 *   to 1 and the request (and its sub-tasks) are released (used = 0).
 * - any other value: nothing is progressed; the call returns scclSuccess with
 *   *done == 0.
 *
 * @param request Pointer to the scclNetSocketRequest to test.
 * @param done    Output: 1 when the request completed, 0 otherwise.
 * @param size    Optional output: number of bytes transferred, written on completion.
 * @return scclSuccess, scclInternalError for a NULL request, scclInvalidUsage
 *         when a received message is larger than the posted buffer, or the
 *         failure reported by a socket operation / sub-task.
 */
scclResult_t scclNetSocketTest(void* request, int* done, int* size) {
*done = 0;
struct scclNetSocketRequest* r = (struct scclNetSocketRequest*)request;
if(r == NULL) {
INFO(SCCL_LOG_CODEALL, "NET/Socket : test called with NULL request");
return scclInternalError;
}
INFO(SCCL_LOG_CODEALL, "NET/Socket : test called request used:%d\n", r->used);
if(r->used == 1) { /* try to send/recv size */
int data = r->size;
int offset = 0;
SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, &data, sizeof(int), &offset));
if(offset == 0)
return scclSuccess; /* Not ready -- retry later */
// Not sure we could ever receive less than 4 bytes, but just in case ...
if(offset < sizeof(int))
SCCLCHECK(scclSocketWait(r->op, r->ctrlSock, &data, sizeof(int), &offset));
// Check size is less or equal to the size provided by the user
if(r->op == SCCL_SOCKET_RECV && data > r->size) {
char line[SOCKET_NAME_MAXLEN + 1];
union scclSocketAddress addr;
scclSocketGetAddr(r->ctrlSock, &addr);
WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d. If you believe your socket network is in healthy state, \
there may be a mismatch in collective sizes or environment settings (e.g. SCCL_PROTO, SCCL_ALGO) between ranks",
scclSocketToString(&addr, line),
data,
r->size);
return scclInvalidUsage;
}
r->size = data;
r->offset = 0;
r->used = 2; // done exchanging size
// divide into subtasks
// NOTE(review): r->comm is dereferenced here without a NULL check.
int chunkOffset = 0, i = 0;
if(r->comm->nSocks > 0) {
// each request can be divided up to nSocks tasks
int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks));
while(chunkOffset < r->size) {
int chunkSize = std::min(taskSize, r->size - chunkOffset);
SCCLCHECK(scclNetSocketGetTask(r->comm, r->op, (char*)(r->data) + chunkOffset, chunkSize, r->tasks + i++));
chunkOffset += chunkSize;
}
}
r->nSubs = i;
}
if(r->used == 2) { // already exchanged size
if(r->nSubs > 0) {
int nCompleted = 0;
for(int i = 0; i < r->nSubs; i++) {
struct scclNetSocketTask* sub = r->tasks[i];
// A failed sub-task fails the whole request.
if(sub->result != scclSuccess)
return sub->result;
if(sub->offset == sub->size)
nCompleted++;
}
if(nCompleted == r->nSubs) {
if(size)
*size = r->size;
*done = 1;
r->used = 0;
// Release the task slots so they can be reused by later requests.
for(int i = 0; i < r->nSubs; i++) {
struct scclNetSocketTask* sub = r->tasks[i];
sub->used = 0;
}
}
} else { // progress request using main thread
if(r->offset < r->size) {
SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, r->data, r->size, &r->offset));
}
if(r->offset == r->size) {
if(size)
*size = r->size;
*done = 1;
r->used = 0;
}
}
}
return scclSuccess;
}
int main(int argc, char* argv[]) {
struct scclNetSocketRequest* request = (struct scclNetSocketRequest*)malloc(sizeof(struct scclNetSocketRequest));
request->op = SCCL_SOCKET_SEND;
request->used = 1;
request->size = 1024;
request->data = (char*)malloc(request->size);
request->ctrlSock = NULL;
request->comm = NULL;
request->nSubs = 0;
int done;
int sizes[32];
printf("test\n");
INFO(SCCL_LOG_CODEALL, "test INFO");
SCCLCHECK(scclSocketInit(request));
SCCLCHECK(scclNetSocketTest(request, &done, sizes));
if(done) {
printf("done\n");
}
}
\ No newline at end of file
...@@ -6,11 +6,11 @@ hipcc ./test_topo.cpp \ ...@@ -6,11 +6,11 @@ hipcc ./test_topo.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/utils.cc \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/utils.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/archinfo.cc \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/archinfo.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/nvmlwrap.cc \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/nvmlwrap.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvsymbols.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_ib/ibvsymbols.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvwrap.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_ib/ibvwrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/net_ib.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_ib/net_ib.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/socket.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_socket/socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/net_socket.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_socket/net_socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_utils.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_utils.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/rocm_wrap.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/rocm_wrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/param.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/param.cpp \
...@@ -25,8 +25,8 @@ hipcc ./test_topo.cpp \ ...@@ -25,8 +25,8 @@ hipcc ./test_topo.cpp \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_ib/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_socket/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ \
-L /public/home/lishen/Code/rocSHMEM/SCCL_v1 \ -L /public/home/lishen/Code/rocSHMEM/SCCL_v1 \
-L /usr/lib/x86_64-linux-gnu -L /usr/lib/ \ -L /usr/lib/x86_64-linux-gnu -L /usr/lib/ \
......
hipcc /public/home/lishen/Code/rocSHMEM/SCCL_v1/examples/2_topo/1_demo_rocm/test_rocm_smi.cpp \ hipcc /public/home/lishen/Code/rocSHMEM/SCCL_v1/examples/2_topo/1_demo_rocm/test_rocm_smi.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/rocm_smi_wrap.cc \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/rocm_smi_wrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo_utils.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo_utils.cpp \
-o test_topo \ -o test_topo \
-std=c++17 -g -O3 -fopenmp -D__HIP_PLATFORM_HCC__ \ -std=c++17 -g -O3 -fopenmp -D__HIP_PLATFORM_HCC__ \
...@@ -11,6 +11,7 @@ hipcc /public/home/lishen/Code/rocSHMEM/SCCL_v1/examples/2_topo/1_demo_rocm/test ...@@ -11,6 +11,7 @@ hipcc /public/home/lishen/Code/rocSHMEM/SCCL_v1/examples/2_topo/1_demo_rocm/test
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/ \
-L /usr/lib/x86_64-linux-gnu \ -L /usr/lib/x86_64-linux-gnu \
-L /usr/lib/ \ -L /usr/lib/ \
-lamdhip64 -lrocm_smi64 -lamdhip64 -lrocm_smi64
\ No newline at end of file
...@@ -11,21 +11,21 @@ using namespace sccl; ...@@ -11,21 +11,21 @@ using namespace sccl;
int main(int argc, char** argv) { int main(int argc, char** argv) {
printf("hello world\n"); printf("hello world\n");
(void)rocm_smi_init(); (void)sccl::hardware::topology::bootstrap::rocm_smi_init();
uint32_t num_devs; uint32_t num_devs;
(void)rocm_smi_getNumDevice(&num_devs); (void)sccl::hardware::topology::bootstrap::rocm_smi_getNumDevice(&num_devs);
printf("num_devs=%d\n", num_devs); printf("num_devs=%d\n", num_devs);
uint32_t deviceIndex = 0; uint32_t deviceIndex = 0;
char bus0[100] = "bus0"; char bus0[100] = "bus0";
(void)rocm_smi_getDevicePciBusIdString(deviceIndex, bus0, 100); (void)sccl::hardware::topology::bootstrap::rocm_smi_getDevicePciBusIdString(deviceIndex, bus0, 100);
printf("bus0=%s\n", bus0); printf("bus0=%s\n", bus0);
RSMI_IO_LINK_TYPE rsmi_type; RSMI_IO_LINK_TYPE rsmi_type;
int hops, count; int hops, count;
(void)rocm_smi_getLinkInfo(0, 8, &rsmi_type, &hops, &count); (void)sccl::hardware::topology::bootstrap::rocm_smi_getLinkInfo(0, 8, &rsmi_type, &hops, &count);
printf("rsmi_type=%d, hops=%d, count=%d\n", rsmi_type, hops, count); printf("rsmi_type=%d, hops=%d, count=%d\n", rsmi_type, hops, count);
// struct sccl::hardware::topology::topo::scclXml* xml; // struct sccl::hardware::topology::topo::scclXml* xml;
......
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
#include <stdint.h> #include <stdint.h>
#include "mpi.h" #include "mpi.h"
#include "net.h" #include "net.h"
#include "bootstrap_net.h" #include "bootstrap.h"
#include "hardware_utils.h"
using namespace sccl; using namespace sccl;
...@@ -23,17 +24,35 @@ int main(int argc, char* argv[]) { ...@@ -23,17 +24,35 @@ int main(int argc, char* argv[]) {
// ----------------------------------------------------------------------- // // ----------------------------------------------------------------------- //
INFO(SCCL_LOG_TOPO, "Bootstrap ...\n"); INFO(SCCL_LOG_TOPO, "Bootstrap ...\n");
struct scclRankInfo* rank_info;
struct sccl::hardware::topology::bootstrap::scclBootstrapComm* comm;
(void)sccl::hardware::topology::bootstrap::bootstrap_net::bootstrapNetInit(); SCCLCHECK(scclCalloc(&rank_info, 1));
SCCLCHECK(scclCalloc(&comm, 1));
rank_info->rank = rank;
rank_info->nRanks = nranks;
rank_info->localRanks = 2;
rank_info->hipDev = rank % rank_info->localRanks;
auto sccl_bootstrap = new sccl::hardware::topology::bootstrap::scclBootstrap(rank_info, comm);
SCCLCHECK(sccl_bootstrap->bootstrapInitCheck());
sccl::hardware::topology::bootstrap::printUniqueInfo(comm->unique_info);
int cuda_id;
HIPCHECK(hipGetDevice(&cuda_id));
printf("rank=%d, cuda_id=%d\n", rank, cuda_id);
MPI_Finalize(); MPI_Finalize();
} }
/* /*
单机执行 单机执行
SCCL_DEBUG_LEVEL=SCCL_LOG_ABORT mpirun --allow-run-as-root -np 2 1_mpi_init SCCL_DEBUG_LEVEL=ABORT mpirun --allow-run-as-root -np 4 1_mpi_init
SCCL_DEBUG_LEVEL=SCCL_LOG_INFO SCCL_DEBUG_POS=SCCL_LOG_CODEALL mpirun --allow-run-as-root -np 2 1_mpi_init SCCL_DEBUG_LEVEL=INFO SCCL_DEBUG_SUBSYS=ALL mpirun --allow-run-as-root -np 2 1_mpi_init
跨机执行 跨机执行
SCCL_DEBUG_LEVEL=SCCL_LOG_ABORT mpirun --allow-run-as-root --hostfile hostfile -np 16 ./1_mpi_init SCCL_DEBUG_LEVEL=ABORT mpirun --allow-run-as-root --hostfile hostfile -np 16 ./1_mpi_init
SCCL_DEBUG_LEVEL=ABORT SCCL_DEBUG_SUBSYS=BOOTSTRAP mpirun --allow-run-as-root --hostfile hostfile2 -np 4 ./1_mpi_init
*/ */
hipcc ./1_mpi_init.cpp \ hipcc ./1_mpi_init.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvsymbols.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/hardware_utils.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvwrap.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_ib/ibvsymbols.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/net_ib.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_ib/ibvwrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/socket.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_ib/net_ib.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/net_socket.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_socket/net_socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_socket/socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ipc_socket/ipc_socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_utils.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_utils.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/rocm_wrap.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/rocm_wrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/param.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/bootstrap_net.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/bootstrap_net.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/ipcsocket.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/bootstrap_utils.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/proxy.cpp \ /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/rocm_smi_wrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo_utils.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/archinfo.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/param.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/utils.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/bootstrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/thread_pool.cpp \
-o 1_mpi_init \ -o 1_mpi_init \
-std=c++17 -g -O3 -fopenmp -DROC_SHMEM -D__HIP_PLATFORM_HCC__ \ -std=c++17 -g -O3 -fopenmp -DROC_SHMEM -D__HIP_PLATFORM_HCC__ -Wno-return-type \
-I ./ -I /usr/include -I /opt/dtk/include \ -I ./ -I /usr/include -I /opt/dtk/include \
-I /public/home/lishen/Code/rocSHMEM/3rd_party/install/ompi/include/ \ -I /public/home/lishen/Code/rocSHMEM/3rd_party/install/ompi/include/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/include/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/include/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_ib/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_socket/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ipc_socket/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/ \ -I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/bootstrap/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/ \
-L /public/home/lishen/Code/rocSHMEM/SCCL_v1 \ -L /public/home/lishen/Code/rocSHMEM/SCCL_v1 \
-L /usr/lib/x86_64-linux-gnu -libverbs -lrdmacm \ -L /usr/lib/x86_64-linux-gnu -libverbs -lrdmacm \
-L /public/home/lishen/Code/rocSHMEM/3rd_party/install/ompi/lib -lmpi -L /public/home/lishen/Code/rocSHMEM/3rd_party/install/ompi/lib -lmpi \
-L /opt/dtk/lib -lamdhip64 -lrocm-core -lrocm_smi64 -pthread
export HSA_FORCE_FINE_GRAIN_PCIE="1"
export iommu=pt
node037 slots=8
node038 slots=8
\ No newline at end of file
node037 slots=2
node038 slots=2
\ No newline at end of file
import os
import glob
from pathlib import Path
def find_cpp_files(directory):
    """Recursively collect the C++ source files under ``directory``.

    Args:
        directory: Root directory to search (``str`` or path-like).

    Returns:
        list[str]: Path strings of every regular file below ``directory``
        whose name matches ``*.cpp``, in the order ``Path.rglob`` yields
        them (the result is not sorted here).
    """
    # rglob matches on the name only, so a directory that happens to be
    # called "something.cpp" would otherwise be reported as a source file.
    return [str(path) for path in Path(directory).rglob('*.cpp') if path.is_file()]
def main(src_path="/public/home/lishen/Code/rocSHMEM/SCCL_v1/src"):
    r"""Print every ``.cpp`` file found under ``src_path`` in sorted order.

    Each path is written on its own line followed by a space and a
    backslash (`` \``). Presumably this is so the listing can be pasted as
    continuation lines of a shell compile command; confirm against how the
    output is actually consumed.

    Args:
        src_path: Root of the source tree to scan. Defaults to the
            original hard-coded location so that calling ``main()`` with no
            arguments behaves exactly as before.
    """
    for cpp_file in sorted(find_cpp_files(src_path)):
        print(cpp_file + ' \\')
# Run the listing only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
#pragma once #pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#include <stdint.h> #include <stdint.h>
#include "base.h" #include "base.h"
#include "topo.h"
namespace sccl { namespace sccl {
namespace hardware { namespace hardware {
// Struct scclUniqueInfo: stores the information of each communication node.
struct scclUniqueInfo {
    int rank;          // Global rank of the current node
    int nRanks;        // Total number of nodes
    int localRank;     // Rank of the current node within its local compute node
    int localRanks;    // Total number of nodes within the local compute node
    int cudaDev;       // CUDA device ID
    int gdrSupport;    // Whether GPU direct registration (GDR) is supported
    uint64_t hostHash; // Host hash value
    uint64_t pidHash;  // Process ID hash value
    int64_t busId;     // Bus ID
};
// // 定义结构体 scclCommBase,用于存储通信基础信息 // // 定义结构体 scclCommBase,用于存储通信基础信息
// struct scclCommBase { // struct scclCommBase {
// struct scclUniqueInfo* peerInfo; // 指向 peerInfo 结构体的指针,存储所有节点的信息 // struct scclUniqueInfo* peerInfo; // 指向 peerInfo 结构体的指针,存储所有节点的信息
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment