Commit d9d23f34 authored by lishen's avatar lishen
Browse files

Initial Code for SCCL_v1

parent 57df3737
###############################################################################
# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
###############################################################################
cmake_minimum_required(VERSION 3.16.3 FATAL_ERROR)
###############################################################################
# AVOID IN SOURCE BUILD
###############################################################################
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_BINARY_DIR AND
CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
set(MSG "")
message(STATUS "Warning! Building from the source directory is not recommended")
message(STATUS "If unintended, please remove 'CMakeCache.txt' and 'CMakeFiles'")
message(STATUS "and build from a separate directory")
message(FATAL_ERROR "In-source build")
endif()
###############################################################################
# CONFIGURATION OPTIONS
###############################################################################
option(DEBUG "Enable debug trace" OFF)
option(PROFILE "Enable statistics and timing support" OFF)
option(USE_RO "Enable RO conduit." ON)
option(USE_IPC "Enable IPC support (using HIP)" OFF)
option(USE_THREADS "Enable workgroup threads to share network queues" OFF)
option(USE_WF_COAL "Enable wavefront message coalescing" OFF)
option(USE_COHERENT_HEAP "Enable support for coherent systems" OFF)
option(USE_MANAGED_HEAP "Enable managed memory" OFF)
option(USE_HOST_HEAP "Enable host memory using malloc/free" OFF)
option(USE_HIP_HOST_HEAP "Enable host memory using hip api" OFF)
option(USE_ALLOC_DLMALLOC "Enable dlmalloc device memory allocator" ON)
option(USE_ALLOC_POW2BINS "Enable legacy Pow2Bins device memory allocator" OFF)
option(USE_FUNC_CALL "Force compiler to use function calls on library API" OFF)
option(USE_SHARED_CTX "Request support for shared ctx between WG" OFF)
option(USE_SINGLE_NODE "Enable single node support only." OFF)
option(USE_HOST_SIDE_HDP_FLUSH "Use a polling thread to flush the HDP cache on the host." OFF)
option(BUILD_FUNCTIONAL_TESTS "Build the functional tests" ON)
option(BUILD_EXAMPLES "Build the examples" ON)
option(BUILD_UNIT_TESTS "Build the unit tests" ON)
option(BUILD_TESTS_ONLY "Build only tests. Used to link agains rocSHMEM in a ROCm Release" OFF)
option(BUILD_LOCAL_GPU_TARGET_ONLY "Build only for GPUs detected on this machine" OFF)
configure_file(cmake/rocshmem_config.h.in rocshmem_config.h)
###############################################################################
# GLOBAL COMPILE FLAGS
###############################################################################
if (DEFINED ENV{ROCM_PATH})
set(ROCM_PATH "$ENV{ROCM_PATH}" CACHE STRING "ROCm install directory")
else()
set(ROCM_PATH "/opt/rocm" CACHE STRING "ROCm install directory")
endif()
if (NOT DEFINED CMAKE_CXX_COMPILER)
set(CMAKE_CXX_COMPILER ${ROCM_PATH}/bin/hipcc)
endif()
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -ggdb")
if (BUILD_TESTS_ONLY)
if (DEFINED ENV{ROCSHMEM_HOME})
set(ROCSHMEM_HOME "$ENV{ROCSHMEM_HOME}")
else()
message("Environment variable ROCSHMEM_HOME is not set.")
message("Assuming that rocSHMEM is installed at ${ROCM_PATH}.")
set(ROCSHMEM_HOME "${ROCM_PATH}")
endif()
endif()
find_package(ROCM PATHS ${ROCM_PATH})
set(ROCMCHECKS_WARN_TOOLCHAIN_VAR OFF)
include(cmake/rocm_local_targets.cmake)
set(DEFAULT_GPUS
gfx936)
###############################################################################
# PROJECT
###############################################################################
find_package(ROCmCMakeBuildTools)
include(ROCMCreatePackage)
include(ROCMInstallTargets)
include(ROCMCheckTargetIds)
rocm_setup_version(VERSION 2.0.0)
project(rocshmem VERSION 2.0.0 LANGUAGES CXX)
add_compile_options(-Wno-return-type)
###############################################################################
# CREATE ROCSHMEM LIBRARY
###############################################################################
if (NOT BUILD_TESTS_ONLY)
add_library(${PROJECT_NAME})
add_library(roc::${PROJECT_NAME} ALIAS ${PROJECT_NAME})
add_subdirectory(src)
#############################################################################
# SET GPU ARCHITECTURES
#############################################################################
if (BUILD_LOCAL_GPU_TARGET_ONLY)
message(STATUS "Building only for local GPU target")
if (COMMAND rocm_local_targets)
rocm_local_targets(DEFAULT_GPUS)
else()
message(WARNING "Unable to determine local GPU targets. Falling back to default GPUs.")
endif()
endif()
set(GPU_TARGETS "${DEFAULT_GPUS}" CACHE STRING
"Target default GPUs if GPU_TARGETS is not defined.")
if (COMMAND rocm_check_target_ids)
message(STATUS "Checking for ROCm support for GPU targets: " "${GPU_TARGETS}")
rocm_check_target_ids(SUPPORTED_GPUS TARGETS ${GPU_TARGETS})
else()
message(WARNING "Unable to check for supported GPU targets. Falling back to default GPUs.")
set(SUPPORTED_GPUS ${DEFAULT_GPUS})
endif()
set(COMPILING_TARGETS "${SUPPORTED_GPUS}" CACHE STRING "GPU targets to compile for.")
message(STATUS "Compiling for ${COMPILING_TARGETS}")
foreach (target ${COMPILING_TARGETS})
list(APPEND offload_flags --offload-arch=${target})
endforeach()
add_compile_options(${offload_flags})
#############################################################################
# PACKAGE DEPENDENCIES
#############################################################################
find_package(MPI REQUIRED)
find_package(hip REQUIRED)
find_package(hsa-runtime64 REQUIRED)
set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
set(THREADS_PREFER_PTHREAD_FLAG TRUE)
find_package(Threads REQUIRED)
#############################################################################
# LINKING AND INCLUDE DIRECTORIES
#############################################################################
target_include_directories(
${PROJECT_NAME}
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${CMAKE_BINARY_DIR}> # rocshmem_config.h
$<INSTALL_INTERFACE:include>
${MPI_CXX_HEADER_DIR}
)
target_link_libraries(
${PROJECT_NAME}
PUBLIC
Threads::Threads
${MPI_mpi_LIBRARY}
${MPI_mpicxx_LIBRARY}
hip::device
hip::host
hsa-runtime64::hsa-runtime64
)
endif()
###############################################################################
# TEST SUBDIRECTORIES
###############################################################################
add_subdirectory(tests)
if (BUILD_EXAMPLES)
add_subdirectory(examples)
endif()
if (NOT BUILD_TESTS_ONLY)
#############################################################################
# INSTALL
#############################################################################
include(ROCMInstallTargets)
include(ROCMCreatePackage)
rocm_install(TARGETS rocshmem)
rocm_install(
DIRECTORY ${CMAKE_SOURCE_DIR}/include/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
rocm_install(
FILES "${CMAKE_BINARY_DIR}/rocshmem_config.h"
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/rocshmem
)
rocm_package_add_dependencies(
DEPENDS
hsa-rocr
hip-runtime-amd
rocm-dev
)
rocm_export_targets(
TARGETS roc::rocshmem
NAMESPACE roc::
)
rocm_create_package(
NAME "rocSHMEM"
DESCRIPTION "ROCm OpenSHMEM (rocSHMEM)"
MAINTAINER "rocSHMEM Maintainer <rocshmem-maintainer@amd.com>"
)
endif()
# HCU Collective Communication Library (HcuCCL)
所有和汇编、builtin指令相关的内容都在`device/utils`文件夹中,上层应用直接使用
SCCL的整体框架如下
![SCCL_framework](docs/images/sccl_v1.png)
SCCL的topo信息获取过程如下
![SCCL_topo](docs/images/topo信息.png)
hipcc ./thread.cpp \
-o thread \
-std=c++17 -g -O3 -fopenmp -DROC_SHMEM -D__HIP_PLATFORM_HCC__
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
void* thread_function(void* arg) {
// 线程开始执行的函数
printf("Thread is running with argument: %s\n", (char*)arg);
return NULL;
}
int main() {
pthread_t thread_id;
const char* message = "Hello, World!";
int result;
// 创建线程
result = pthread_create(&thread_id, NULL, thread_function, (void*)message);
if(result != 0) {
perror("Thread creation failed");
exit(EXIT_FAILURE);
}
printf("Thread created successfully\n");
pthread_exit(NULL); // 等待线程结束
return 0;
}
\ No newline at end of file
#include <iostream>
#include "net.h"
using namespace sccl;
int main(int argc, char* argv[]) {
INFO(SCCL_LOG_CODEALL, "Hello, World!");
// SCCLCHECK(scclSystemError);
// SCCLCHECK(sccl::hardware::net::device::scclIbInit());
// SCCLCHECK(sccl::hardware::net::device::scclIbGetDevicesNum(&n_ib));
// printf("device num=%d\n", n_ib);
// ----------------------------------------------------------------------- //
auto scclNet = sccl::hardware::net::initNet(sccl::hardware::net::NET_IB);
// auto scclNet = sccl::hardware::net::initNet(sccl::hardware::net::NET_SOCKET);
sccl::hardware::net::scclNetProperties_t props;
int n_ib;
scclNet->devices(&n_ib);
printf("device num=%d\n", n_ib);
scclNet->getProperties(0, &props);
printf("device name=%s\n", props.name);
printf("device pciPath=%s\n", props.pciPath);
printf("device guid=%lu\n", props.guid);
printf("device ptrSupport=%d\n", props.ptrSupport);
printf("device speed=%d\n", props.speed);
printf("device port=%d\n", props.port);
printf("device latency=%f\n", props.latency);
printf("device maxComms=%d\n", props.maxComms);
printf("device maxRecvs=%d\n", props.maxRecvs);
// 程序成功执行,返回0
return 0;
}
// HIP_VISIBLE_DEVICES=1 ./1_simple
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include "mpi.h"
#include "net.h"
using namespace sccl;
// int main(int argc, char* argv[]) {
// INFO(SCCL_LOG_CODEALL, "Hello, World!");
// // SCCLCHECK(scclSystemError);
// // SCCLCHECK(sccl::hardware::net::device::scclIbInit());
// // SCCLCHECK(sccl::hardware::net::device::scclIbGetDevicesNum(&n_ib));
// // printf("device num=%d\n", n_ib);
// // ----------------------------------------------------------------------- //
// // auto scclNet = sccl::hardware::net::initNet(sccl::hardware::net::NET_IB);
// auto scclNet = sccl::hardware::net::initNet(sccl::hardware::net::NET_SOCKET);
// sccl::hardware::net::scclNetProperties_t props;
// int n_ib;
// scclNet->devices(&n_ib);
// printf("device num=%d\n", n_ib);
// scclNet->getProperties(0, &props);
// printf("device name=%s\n", props.name);
// printf("device pciPath=%s\n", props.pciPath);
// printf("device guid=%lu\n", props.guid);
// printf("device ptrSupport=%d\n", props.ptrSupport);
// printf("device speed=%d\n", props.speed);
// printf("device port=%d\n", props.port);
// printf("device latency=%f\n", props.latency);
// printf("device maxComms=%d\n", props.maxComms);
// printf("device maxRecvs=%d\n", props.maxRecvs);
// // 程序成功执行,返回0
// return 0;
// }
int main(int argc, char* argv[]) {
int rank, nranks;
int tag1, src, dst, cnt;
MPI_Status status;
MPI_Init(&argc, &argv);
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
printf("rank=%d, nranks=%d\n", rank, nranks);
// ----------------------------------------------------------------------- //
#if 0
{
auto scclNet = sccl::hardware::net::initNet(sccl::hardware::net::NET_SOCKET);
sccl::hardware::net::scclNetProperties_t props;
int n_ib;
scclNet->devices(&n_ib);
int local_rank = rank % n_ib;
scclNet->getProperties(local_rank, &props);
int cuda_dev = local_rank;
char busIdStr[] = "00000000:00:00.0";
(void)hipDeviceGetPCIBusId(busIdStr, sizeof(busIdStr), cuda_dev);
printf("rank=%d/%d, n_ib=%d, device name=%s, bus_id=%s, pciPath=%s,guid=%lu, ptrSupport=%d, speed=%d, port=%d, latency=%f, maxComms=%d, maxRecvs=%d\n",
rank,
nranks,
n_ib,
props.name,
busIdStr,
props.pciPath,
props.guid,
props.ptrSupport,
props.speed,
props.port,
props.latency,
props.maxComms,
props.maxRecvs);
}
#endif
#if 1
{
auto scclNet = sccl::hardware::net::initNet(sccl::hardware::net::NET_IB);
sccl::hardware::net::scclNetProperties_t props;
int n_ib;
scclNet->devices(&n_ib);
int local_rank = rank % n_ib;
scclNet->getProperties(local_rank, &props);
#define MAX_BUSID_SIZE 16
int cuda_dev = local_rank;
char busIdStr[] = "00000000:00:00.0";
(void)hipDeviceGetPCIBusId(busIdStr, sizeof(busIdStr), cuda_dev);
printf("rank=%d/%d, n_ib=%d, device name=%s, bus_id=%s, pciPath=%s,guid=%lu, ptrSupport=%d, speed=%d, port=%d, latency=%f, maxComms=%d, maxRecvs=%d\n",
rank,
nranks,
n_ib,
props.name,
busIdStr,
props.pciPath,
props.guid,
props.ptrSupport,
props.speed,
props.port,
props.latency,
props.maxComms,
props.maxRecvs);
}
#endif
MPI_Finalize();
}
/*
单机执行
SCCL_DEBUG_LEVEL=SCCL_LOG_ABORT mpirun --allow-run-as-root -np 8 2_mpi_get
SCCL_DEBUG_LEVEL=SCCL_LOG_INFO mpirun --allow-run-as-root -np 8 2_mpi_get
跨机执行
SCCL_DEBUG_LEVEL=SCCL_LOG_ABORT mpirun --allow-run-as-root --hostfile hostfile -np 16 ./2_mpi_get
*/
#include <infiniband/verbs.h>
void check_network_connections() {
struct ibv_device** dev_list;
struct ibv_context* context;
struct ibv_port_attr port_attr;
int num_devices, i, port_num;
// 获取设备列表
dev_list = ibv_get_device_list(&num_devices);
if(!dev_list) {
fprintf(stderr, "Failed to get IB device list\n");
return;
}
// 遍历设备列表
for(i = 0; i < num_devices; i++) {
context = ibv_open_device(dev_list[i]);
if(!context) {
fprintf(stderr, "Failed to open device %s\n", ibv_get_device_name(dev_list[i]));
continue;
}
// 假设我们只检查端口 1
port_num = 1;
if(ibv_query_port(context, port_num, &port_attr)) {
fprintf(stderr, "Failed to query port %d attributes on device %s\n", port_num, ibv_get_device_name(dev_list[i]));
ibv_close_device(context);
continue;
}
// 检查端口状态和连接状态
if(port_attr.state == IBV_PORT_ACTIVE && port_attr.phys_state == 5) { // 5 表示端口已连接
printf("Device %s, Port %d is connected.\n", ibv_get_device_name(dev_list[i]), port_num);
} else {
printf("Device %s, Port %d is not connected.\n", ibv_get_device_name(dev_list[i]), port_num);
}
ibv_close_device(context);
}
ibv_free_device_list(dev_list);
}
int main(int argc, char* argv[]) {
// 获取设备列表
check_network_connections();
return 0;
}
hipcc ./2_mpi_get.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvsymbols.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvwrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/net_ib.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/net_socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_utils.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/rocm_wrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/param.cpp \
-o 2_mpi_get \
-std=c++17 -g -O3 -fopenmp -DROC_SHMEM -D__HIP_PLATFORM_HCC__ \
-I ./ -I /usr/include -I /opt/dtk/include \
-I /public/home/lishen/Code/rocSHMEM/3rd_party/install/ompi/include/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/include/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ \
-L /public/home/lishen/Code/rocSHMEM/SCCL_v1 \
-L /usr/lib/x86_64-linux-gnu -libverbs -lrdmacm \
-L /public/home/lishen/Code/rocSHMEM/3rd_party/install/ompi/lib -lmpi
hipcc ./3_rdma_info.cpp \
-o 3_rdma_info \
-std=c++17 -g -O3 -fopenmp -DROC_SHMEM -D__HIP_PLATFORM_HCC__ \
-I ./ -I /usr/include -I /opt/dtk/include \
-L /usr/lib/x86_64-linux-gnu -libverbs -lrdmacm
hipcc ./1_simple.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvsymbols.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvwrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/net_ib.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/net_socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_utils.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/rocm_wrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/param.cpp \
-o 1_simple \
-std=c++17 -g -O3 -fopenmp -DROC_SHMEM -D__HIP_PLATFORM_HCC__ \
-I ./ -I /usr/include -I /opt/dtk/include \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/include/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ \
-L /public/home/lishen/Code/rocSHMEM/SCCL_v1 \
-L /usr/lib/x86_64-linux-gnu -libverbs -lrdmacm
node037 slots=8
node038 slots=8
\ No newline at end of file
#include <iostream>
#include <string>
#include <cstring>
#include <unistd.h>
#include <arpa/inet.h>
void start_client(const std::string& server_ip, int server_port) {
int sock = 0;
struct sockaddr_in serv_addr;
char buffer[1024] = {0};
std::string message = "你好,服务器!";
// 创建 socket 文件描述符
if((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
std::cerr << "Socket creation error" << std::endl;
exit(EXIT_FAILURE);
}
serv_addr.sin_family = AF_INET;
serv_addr.sin_port = htons(server_port);
// 转换 IPv4 和 IPv6 地址
if(inet_pton(AF_INET, server_ip.c_str(), &serv_addr.sin_addr) <= 0) {
std::cerr << "Invalid address/ Address not supported" << std::endl;
close(sock);
exit(EXIT_FAILURE);
}
// 连接到服务器
if(connect(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) < 0) {
std::cerr << "Connection Failed" << std::endl;
close(sock);
exit(EXIT_FAILURE);
}
// 发送数据
send(sock, message.c_str(), message.length(), 0);
std::cout << "消息已发送" << std::endl;
// 接收响应
int valread = read(sock, buffer, 1024);
std::cout << "收到的响应: " << buffer << std::endl;
// 关闭连接
close(sock);
}
int main() {
std::string server_ip = "10.16.1.37";
int server_port = 6842;
start_client(server_ip, server_port);
return 0;
}
\ No newline at end of file
hipcc ./test_socket_itf.cpp \
./socket.cpp \
-o test_socket_itf \
-std=c++17 --offload-arch=gfx936 -g -O3 -fopenmp -D__HIP_PLATFORM_HCC__ \
-I ./ \
-I /usr/include \
-I /opt/dtk/include \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/ \
-L /usr/lib/x86_64-linux-gnu -lpthread -lrt
\ No newline at end of file
#include <iostream>
#include <ifaddrs.h>
#include <arpa/inet.h>
#include <net/if.h>
#include <stdlib.h>
#include <netdb.h>
#include <unistd.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <vector>
#include <utility>
#include <unordered_set>
#include <unistd.h>
#include <sys/syscall.h>
#define NI_MAXHOST 1025
void get_ip_addresses() {
struct ifaddrs *ifaddr, *ifa;
char host[NI_MAXHOST];
if(getifaddrs(&ifaddr) == -1) {
perror("getifaddrs");
exit(EXIT_FAILURE);
}
for(ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) {
if(ifa->ifa_addr == NULL)
continue;
if(ifa->ifa_addr->sa_family == AF_INET) { // 检查是否为 IPv4 地址
(void)getnameinfo(ifa->ifa_addr, sizeof(struct sockaddr_in), host, NI_MAXHOST, NULL, 0, NI_NUMERICHOST);
std::cout << "Interface: " << ifa->ifa_name << " Address: " << host << std::endl;
}
}
freeifaddrs(ifaddr);
}
int main() {
get_ip_addresses();
return 0;
}
\ No newline at end of file
#include <iostream>
#include <string>
#include <cstring>
#include <unistd.h>
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
void start_server() {
int server_fd, new_socket;
struct sockaddr_in address;
int addrlen = sizeof(address);
char buffer[1024] = {0};
std::string message = "消息已收到";
// 创建 socket 文件描述符
if((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) {
perror("socket failed");
exit(EXIT_FAILURE);
}
// 绑定地址和端口
address.sin_family = AF_INET;
address.sin_addr.s_addr = INADDR_ANY; // 自动获取所有 IP 地址
address.sin_port = htons(6842);
if(bind(server_fd, (struct sockaddr*)&address, sizeof(address)) < 0) {
perror("bind failed");
close(server_fd);
exit(EXIT_FAILURE);
}
// 获取绑定的端口号
socklen_t len = sizeof(address);
if(getsockname(server_fd, (struct sockaddr*)&address, &len) == -1) {
perror("getsockname failed");
close(server_fd);
exit(EXIT_FAILURE);
}
int port = ntohs(address.sin_port);
std::cout << "服务器已启动,端口: " << port << std::endl;
// 监听连接
if(listen(server_fd, 3) < 0) {
perror("listen");
close(server_fd);
exit(EXIT_FAILURE);
}
std::cout << "等待连接..." << std::endl;
// 接受客户端连接
if((new_socket = accept(server_fd, (struct sockaddr*)&address, (socklen_t*)&addrlen)) < 0) {
perror("accept");
close(server_fd);
exit(EXIT_FAILURE);
}
while(true) {
// 接收数据
int valread = read(new_socket, buffer, 1024);
if(valread == 0) {
break;
}
std::cout << "收到的消息: " << buffer << std::endl;
send(new_socket, message.c_str(), message.length(), 0);
memset(buffer, 0, sizeof(buffer));
}
// 关闭连接
close(new_socket);
close(server_fd);
}
int main() {
start_server();
return 0;
}
\ No newline at end of file
#include "socket.h"
#include <stdlib.h>
#include <unistd.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <vector>
#include <utility>
#include <unordered_set>
#include <unistd.h>
#include <sys/syscall.h>
using namespace sccl;
static std::vector<std::pair<int, std::unordered_set<std::string>>> clientPortPool;
static scclResult_t socketProgressOpt(int op, struct scclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) {
int bytes = 0;
*closed = 0;
char* data = (char*)ptr;
char line[SOCKET_NAME_MAXLEN + 1];
do {
if(op == SCCL_SOCKET_RECV)
bytes = recv(sock->fd, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT);
if(op == SCCL_SOCKET_SEND)
bytes = send(sock->fd, data + (*offset), size - (*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL);
if(op == SCCL_SOCKET_RECV && bytes == 0) {
*closed = 1;
return scclSuccess;
}
if(bytes == -1) {
if(errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
WARN("socketProgressOpt: Call to recv from %s failed : %s", scclSocketToString(&sock->addr, line), strerror(errno));
return scclRemoteError;
} else {
bytes = 0;
}
}
(*offset) += bytes;
if(sock->abortFlag && *sock->abortFlag != 0) {
INFO(SCCL_LOG_CODEALL, "socketProgressOpt: abort called");
return scclInternalError;
}
} while(bytes > 0 && (*offset) < size);
return scclSuccess;
}
static scclResult_t socketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
int closed;
SCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0 /*block*/, &closed));
if(closed) {
char line[SOCKET_NAME_MAXLEN + 1];
WARN("socketProgress: Connection closed by remote peer %s", scclSocketToString(&sock->addr, line, 0));
return scclRemoteError;
}
return scclSuccess;
}
static scclResult_t socketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
while(*offset < size)
SCCLCHECK(socketProgress(op, sock, ptr, size, offset));
return scclSuccess;
}
/* Format a string representation of a (union scclSocketAddress *) socket address using getnameinfo()
*
* Output: "IPv4/IPv6 address<port>"
*/
const char* scclSocketToString(union scclSocketAddress* addr, char* buf, const int numericHostForm /*= 1*/) {
if(buf == NULL || addr == NULL)
return NULL;
struct sockaddr* saddr = &addr->sa;
if(saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) {
buf[0] = '\0';
return buf;
}
char host[NI_MAXHOST], service[NI_MAXSERV];
/* NI_NUMERICHOST: If set, then the numeric form of the hostname is returned.
* (When not set, this will still happen in case the node's name cannot be determined.)
*/
int flag = NI_NUMERICSERV | (numericHostForm ? NI_NUMERICHOST : 0);
(void)getnameinfo(saddr, sizeof(union scclSocketAddress), host, NI_MAXHOST, service, NI_MAXSERV, flag);
sprintf(buf, "%s<%s>", host, service);
return buf;
}
static uint16_t socketToPort(union scclSocketAddress* addr) {
struct sockaddr* saddr = &addr->sa;
return ntohs(saddr->sa_family == AF_INET ? addr->sin.sin_port : addr->sin6.sin6_port);
}
/* Allow the user to force the IPv4/IPv6 interface selection */
static int envSocketFamily(void) {
int family = -1; // Family selection is not forced, will use first one found
char* env = getenv("SCCL_SOCKET_FAMILY");
if(env == NULL)
return family;
INFO(SCCL_LOG_CODEALL, "SCCL_SOCKET_FAMILY set by environment to %s", env);
if(strcmp(env, "AF_INET") == 0)
family = AF_INET; // IPv4
else if(strcmp(env, "AF_INET6") == 0)
family = AF_INET6; // IPv6
return family;
}
static int findInterfaces(const char* prefixList, char* names, union scclSocketAddress* addrs, int sock_family, int maxIfNameSize, int maxIfs) {
struct netIf userIfs[MAX_IFS];
bool searchNot = prefixList && prefixList[0] == '^';
if(searchNot)
prefixList++;
bool searchExact = prefixList && prefixList[0] == '=';
if(searchExact)
prefixList++;
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
for(interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) {
if(interface->ifa_addr == NULL)
continue;
/* We only support IPv4 & IPv6 */
int family = interface->ifa_addr->sa_family;
if(family != AF_INET && family != AF_INET6)
continue;
/* Allow the caller to force the socket family type */
if(sock_family != -1 && family != sock_family)
continue;
/* We also need to skip IPv6 loopback interfaces */
if(family == AF_INET6) {
struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr);
if(IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr))
continue;
}
// check against user specified interfaces
if(!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
continue;
}
// Check that this interface has not already been saved
// getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link
bool duplicate = false;
for(int i = 0; i < found; i++) {
if(strcmp(interface->ifa_name, names + i * maxIfNameSize) == 0) {
duplicate = true;
break;
}
}
if(!duplicate) {
// Store the interface name
strncpy(names + found * maxIfNameSize, interface->ifa_name, maxIfNameSize);
// Store the IP address
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
memcpy(addrs + found, interface->ifa_addr, salen);
found++;
}
}
freeifaddrs(interfaces);
return found;
}
static bool matchSubnet(struct ifaddrs local_if, union scclSocketAddress* remote) {
/* Check family first */
int family = local_if.ifa_addr->sa_family;
if(family != remote->sa.sa_family) {
return false;
}
if(family == AF_INET) {
struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr);
struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask);
struct sockaddr_in& remote_addr = remote->sin;
struct in_addr local_subnet, remote_subnet;
local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr;
return (local_subnet.s_addr ^ remote_subnet.s_addr) ? false : true;
} else if(family == AF_INET6) {
struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr);
struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask);
struct sockaddr_in6& remote_addr = remote->sin6;
struct in6_addr& local_in6 = local_addr->sin6_addr;
struct in6_addr& mask_in6 = mask->sin6_addr;
struct in6_addr& remote_in6 = remote_addr.sin6_addr;
bool same = true;
int len = 16; // IPv6 address is 16 unsigned char
for(int c = 0; c < len; c++) { // Network byte order is big-endian
char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c];
char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c];
if(c1 ^ c2) {
same = false;
break;
}
}
// At last, we need to compare scope id
// Two Link-type addresses can have the same subnet address even though they are not in the same scope
// For Global type, this field is 0, so a comparison wouldn't matter
same &= (local_addr->sin6_scope_id == remote_addr.sin6_scope_id);
return same;
} else {
WARN("Net : Unsupported address family type");
return false;
}
}
int scclFindInterfaceMatchSubnet(char* ifNames, union scclSocketAddress* localAddrs, union scclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) {
char line_a[SOCKET_NAME_MAXLEN + 1];
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
for(interface = interfaces; interface && !found; interface = interface->ifa_next) {
if(interface->ifa_addr == NULL)
continue;
/* We only support IPv4 & IPv6 */
int family = interface->ifa_addr->sa_family;
if(family != AF_INET && family != AF_INET6)
continue;
// check against user specified interfaces
if(!matchSubnet(*interface, remoteAddr)) {
continue;
}
// Store the local IP address
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
memcpy(localAddrs + found, interface->ifa_addr, salen);
// Store the interface name
strncpy(ifNames + found * ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
found++;
if(found == maxIfs)
break;
}
if(found == 0) {
WARN("Net : No interface found in the same subnet as remote address %s", scclSocketToString(remoteAddr, line_a));
}
freeifaddrs(interfaces);
return found;
}
scclResult_t scclSocketGetAddrFromString(union scclSocketAddress* ua, const char* ip_port_pair) {
if(!(ip_port_pair && strlen(ip_port_pair) > 1)) {
WARN("Net : string is null");
return scclInvalidArgument;
}
bool ipv6 = ip_port_pair[0] == '[';
/* Construct the sockaddress structure */
if(!ipv6) {
struct netIf ni;
// parse <ip_or_hostname>:<port> string, expect one pair
if(parseStringList(ip_port_pair, &ni, 1) != 1) {
WARN("Net : No valid <IPv4_or_hostname>:<port> pair found");
return scclInvalidArgument;
}
struct addrinfo hints, *p;
int rv;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
if((rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
WARN("Net : error encountered when getting address info : %s", gai_strerror(rv));
return scclInvalidArgument;
}
// use the first
if(p->ai_family == AF_INET) {
struct sockaddr_in& sin = ua->sin;
memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in));
sin.sin_family = AF_INET; // IPv4
// inet_pton(AF_INET, ni.prefix, &(sin.sin_addr)); // IP address
sin.sin_port = htons(ni.port); // port
} else if(p->ai_family == AF_INET6) {
struct sockaddr_in6& sin6 = ua->sin6;
memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6));
sin6.sin6_family = AF_INET6; // IPv6
sin6.sin6_port = htons(ni.port); // port
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
sin6.sin6_scope_id = 0; // should be global scope, set to 0
} else {
WARN("Net : unsupported IP family");
return scclInvalidArgument;
}
freeaddrinfo(p); // all done with this structure
} else {
int i, j = -1, len = strlen(ip_port_pair);
for(i = 1; i < len; i++) {
if(ip_port_pair[i] == '%')
j = i;
if(ip_port_pair[i] == ']')
break;
}
if(i == len) {
WARN("Net : No valid [IPv6]:port pair found");
return scclInvalidArgument;
}
bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope
char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
memset(ip_str, '\0', sizeof(ip_str));
memset(port_str, '\0', sizeof(port_str));
memset(if_name, '\0', sizeof(if_name));
strncpy(ip_str, ip_port_pair + 1, global_scope ? i - 1 : j - 1);
strncpy(port_str, ip_port_pair + i + 2, len - i - 1);
int port = atoi(port_str);
if(!global_scope)
strncpy(if_name, ip_port_pair + j + 1, i - j - 1); // If not global scope, we need the intf name
struct sockaddr_in6& sin6 = ua->sin6;
sin6.sin6_family = AF_INET6; // IPv6
inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address
sin6.sin6_port = htons(port); // port
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope
}
return scclSuccess;
}
int scclFindInterfaces(char* ifNames, union scclSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs) {
static int shownIfName = 0;
int nIfs = 0;
// Allow user to force the INET socket family selection
int sock_family = envSocketFamily();
// User specified interface
char* env = getenv("SCCL_SOCKET_IFNAME");
if(env && strlen(env) > 1) {
INFO(SCCL_LOG_CODEALL, "SCCL_SOCKET_IFNAME set by environment to %s", env);
// Specified by user : find or fail
if(shownIfName++ == 0)
INFO(SCCL_LOG_CODEALL, "SCCL_SOCKET_IFNAME set to %s", env);
nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
} else {
// Try to automatically pick the right one
// Start with IB
nIfs = findInterfaces("ib", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// else see if we can get some hint from COMM ID
if(nIfs == 0) {
char* commId = getenv("SCCL_COMM_ID");
if(commId && strlen(commId) > 1) {
INFO(SCCL_LOG_CODEALL, "SCCL_COMM_ID set by environment to %s", commId);
// Try to find interface that is in the same subnet as the IP in comm id
union scclSocketAddress idAddr;
scclSocketGetAddrFromString(&idAddr, commId);
nIfs = scclFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);
}
}
// Then look for anything else (but not docker or lo)
if(nIfs == 0)
nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// Finally look for docker, then lo.
if(nIfs == 0)
nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
if(nIfs == 0)
nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
}
return nIfs;
}
scclResult_t scclSocketListen(struct scclSocket* sock) {
if(sock == NULL) {
WARN("scclSocketListen: pass NULL socket");
return scclInvalidArgument;
}
if(sock->fd == -1) {
WARN("scclSocketListen: file descriptor is -1");
return scclInvalidArgument;
}
if(socketToPort(&sock->addr)) {
// Port is forced by env. Make sure we get the port.
int opt = 1;
#if defined(SO_REUSEPORT)
SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
#else
SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt");
#endif
}
// addr port should be 0 (Any port)
SYSCHECK(bind(sock->fd, &sock->addr.sa, sock->salen), "bind");
/* Get the assigned Port */
socklen_t size = sock->salen;
SYSCHECK(getsockname(sock->fd, &sock->addr.sa, &size), "getsockname");
/* Put the socket in listen mode
* NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn
*/
SYSCHECK(listen(sock->fd, 16384), "listen");
sock->state = scclSocketStateReady;
return scclSuccess;
}
scclResult_t scclSocketGetAddr(struct scclSocket* sock, union scclSocketAddress* addr) {
if(sock == NULL) {
WARN("scclSocketGetAddr: pass NULL socket");
return scclInvalidArgument;
}
if(sock->state != scclSocketStateReady)
return scclInternalError;
memcpy(addr, &sock->addr, sizeof(union scclSocketAddress));
return scclSuccess;
}
static scclResult_t socketTryAccept(struct scclSocket* sock) {
socklen_t socklen = sizeof(union scclSocketAddress);
sock->fd = accept(sock->acceptFd, &sock->addr.sa, &socklen);
if(sock->fd != -1) {
sock->state = scclSocketStateAccepted;
} else if(errno != EAGAIN && errno != EWOULDBLOCK) {
WARN("socketTryAccept: Accept failed: %s", strerror(errno));
return scclSystemError;
}
return scclSuccess;
}
static scclResult_t socketFinalizeAccept(struct scclSocket* sock) {
uint64_t magic;
enum scclSocketType type;
int received = 0;
const int one = 1;
SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");
SCCLCHECK(scclSocketProgress(SCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
if(received == 0)
return scclSuccess;
SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
if(magic != sock->magic) {
WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic);
close(sock->fd);
sock->fd = -1;
// Ignore spurious connection and accept again
sock->state = scclSocketStateAccepting;
return scclSuccess;
} else {
received = 0;
SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, &type, sizeof(type), &received));
if(type != sock->type) {
WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type);
sock->state = scclSocketStateError;
close(sock->fd);
sock->fd = -1;
return scclInternalError;
} else {
sock->state = scclSocketStateReady;
}
}
return scclSuccess;
}
static scclResult_t socketStartConnect(struct scclSocket* sock) {
/* blocking/non-blocking connect() is determined by asyncFlag. */
int ret = connect(sock->fd, &sock->addr.sa, sock->salen);
if(ret == 0) {
sock->state = scclSocketStateConnected;
return scclSuccess;
} else if(errno == EINPROGRESS) {
sock->state = scclSocketStateConnectPolling;
return scclSuccess;
} else if(errno == ECONNREFUSED) {
if(++sock->refusedRetries == RETRY_REFUSED_TIMES) {
sock->state = scclSocketStateError;
WARN("socketStartConnect: exceeded retries (%d)", sock->refusedRetries);
return scclRemoteError;
}
usleep(SLEEP_INT);
if(sock->refusedRetries % 1000 == 0)
INFO(SCCL_LOG_CODEALL, "Call to connect returned %s, retrying", strerror(errno));
return scclSuccess;
} else if(errno == ETIMEDOUT) {
if(++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) {
sock->state = scclSocketStateError;
WARN("socketStartConnect: exceeded timeouts (%d)", sock->timedOutRetries);
return scclRemoteError;
}
usleep(SLEEP_INT);
return scclSuccess;
} else {
char line[SOCKET_NAME_MAXLEN + 1];
sock->state = scclSocketStateError;
WARN("socketStartConnect: Connect to %s failed : %s", scclSocketToString(&sock->addr, line), strerror(errno));
return scclSystemError;
}
}
static scclResult_t socketPollConnect(struct scclSocket* sock) {
struct pollfd pfd;
int timeout = 1, ret;
socklen_t rlen = sizeof(int);
memset(&pfd, 0, sizeof(struct pollfd));
pfd.fd = sock->fd;
pfd.events = POLLOUT;
ret = poll(&pfd, 1, timeout);
if(ret == 0 || (ret < 0 && errno == EINTR)) {
return scclSuccess;
} else if(ret < 0) {
WARN("socketPollConnect poll() failed with error %s", strerror(errno));
return scclRemoteError;
} else {
EQCHECK(ret == 1 && (pfd.revents & POLLOUT), 0);
}
/* check socket status */
SYSCHECK(getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt");
if(ret == 0) {
sock->state = scclSocketStateConnected;
} else if(ret == ECONNREFUSED) {
if(++sock->refusedRetries == RETRY_REFUSED_TIMES) {
sock->state = scclSocketStateError;
WARN("socketPollConnect: exceeded retries (%d)", sock->refusedRetries);
return scclRemoteError;
}
if(sock->refusedRetries % 1000 == 0)
INFO(SCCL_LOG_CODEALL, "Call to connect returned %s, retrying", strerror(errno));
usleep(SLEEP_INT);
sock->state = scclSocketStateConnecting;
} else if(ret == ETIMEDOUT) {
if(++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) {
sock->state = scclSocketStateError;
WARN("socketPollConnect: exceeded timeouts (%d)", sock->timedOutRetries);
return scclRemoteError;
}
usleep(SLEEP_INT);
sock->state = scclSocketStateConnecting;
} else if(ret != EINPROGRESS) {
sock->state = scclSocketStateError;
return scclSystemError;
}
return scclSuccess;
}
scclResult_t scclSocketPollConnect(struct scclSocket* sock) {
if(sock == NULL) {
WARN("scclSocketPollConnect: pass NULL socket");
return scclInvalidArgument;
}
SCCLCHECK(socketPollConnect(sock));
return scclSuccess;
}
static scclResult_t socketFinalizeConnect(struct scclSocket* sock) {
int sent = 0;
SCCLCHECK(socketProgress(SCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));
if(sent == 0)
return scclSuccess;
SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));
sent = 0;
SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent));
sock->state = scclSocketStateReady;
return scclSuccess;
}
static scclResult_t socketProgressState(struct scclSocket* sock) {
if(sock->state == scclSocketStateAccepting) {
SCCLCHECK(socketTryAccept(sock));
}
if(sock->state == scclSocketStateAccepted) {
SCCLCHECK(socketFinalizeAccept(sock));
}
if(sock->state == scclSocketStateConnecting) {
SCCLCHECK(socketStartConnect(sock));
}
if(sock->state == scclSocketStateConnectPolling) {
SCCLCHECK(socketPollConnect(sock));
}
if(sock->state == scclSocketStateConnected) {
SCCLCHECK(socketFinalizeConnect(sock));
}
return scclSuccess;
}
scclResult_t scclSocketReady(struct scclSocket* sock, int* running) {
if(sock == NULL) {
*running = 0;
return scclSuccess;
}
if(sock->state == scclSocketStateError || sock->state == scclSocketStateClosed) {
WARN("scclSocketReady: unexpected socket state %d", sock->state);
return scclRemoteError;
}
*running = (sock->state == scclSocketStateReady) ? 1 : 0;
if(*running == 0) {
SCCLCHECK(socketProgressState(sock));
*running = (sock->state == scclSocketStateReady) ? 1 : 0;
}
return scclSuccess;
}
scclResult_t scclSocketConnect(struct scclSocket* sock, int portReuse) {
char line[SOCKET_NAME_MAXLEN + 1];
const int one = 1;
if(sock == NULL) {
WARN("scclSocketConnect: pass NULL socket");
return scclInvalidArgument;
}
if(sock->fd == -1) {
WARN("scclSocketConnect: file descriptor is -1");
return scclInvalidArgument;
}
if(sock->state != scclSocketStateInitialized) {
WARN("scclSocketConnect: wrong socket state %d", sock->state);
if(sock->state == scclSocketStateError)
return scclRemoteError;
return scclInternalError;
}
SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");
if(portReuse) {
int family = sock->addr.sa.sa_family;
if(family != AF_INET && family != AF_INET6) {
WARN("Net : connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
scclSocketToString(&sock->addr, line),
family,
AF_INET,
AF_INET6);
return scclInternalError;
}
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in)
: sizeof(struct sockaddr_in6); // pre-define ports according to tid, to avoid extra lock for race condition
if(clientPortPool.size() == 0) {
for(int tid = syscall(SYS_gettid), i = 1; i < 5; i++) {
clientPortPool.push_back(std::make_pair(60000 + i * 1000 + tid % 1000, std::unordered_set<std::string>()));
}
}
// find a port without conflict (different remote peer) in best effort
int reused_port = -1;
std::string remote_peer(scclSocketToString(&sock->addr, line));
for(auto& port : clientPortPool) {
if(port.second.find(remote_peer) == port.second.end()) {
reused_port = port.first;
port.second.insert(remote_peer);
break;
}
}
// bind the port in fd for connect system call
if(reused_port != -1) {
int opt = 1;
SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
struct sockaddr_in sin;
sin.sin_family = family;
sin.sin_addr.s_addr = htonl(INADDR_ANY);
sin.sin_port = htons(reused_port);
SYSCHECK(bind(sock->fd, (struct sockaddr*)&sin, salen), "bind_client_port");
}
}
sock->state = scclSocketStateConnecting;
do {
SCCLCHECK(socketProgressState(sock));
} while(sock->asyncFlag == 0 && (sock->abortFlag == NULL || *sock->abortFlag == 0) &&
(sock->state == scclSocketStateConnecting || sock->state == scclSocketStateConnectPolling || sock->state == scclSocketStateConnected));
if(sock->abortFlag && *sock->abortFlag != 0)
return scclInternalError;
switch(sock->state) {
case scclSocketStateConnecting:
case scclSocketStateConnectPolling:
case scclSocketStateConnected:
case scclSocketStateReady: return scclSuccess;
case scclSocketStateError: return scclSystemError;
default: WARN("scclSocketConnect: wrong socket state %d", sock->state); return scclInternalError;
}
}
scclResult_t scclSocketAccept(struct scclSocket* sock, struct scclSocket* listenSock) {
scclResult_t ret = scclSuccess;
if(listenSock == NULL || sock == NULL) {
WARN("scclSocketAccept: pass NULL socket");
ret = scclInvalidArgument;
goto exit;
}
if(listenSock->state != scclSocketStateReady) {
WARN("scclSocketAccept: wrong socket state %d", listenSock->state);
if(listenSock->state == scclSocketStateError)
ret = scclSystemError;
else
ret = scclInternalError;
goto exit;
}
if(sock->acceptFd == -1) {
memcpy(sock, listenSock, sizeof(struct scclSocket));
sock->acceptFd = listenSock->fd;
sock->state = scclSocketStateAccepting;
}
do {
SCCLCHECKGOTO(socketProgressState(sock), ret, exit);
} while(sock->asyncFlag == 0 && (sock->abortFlag == NULL || *sock->abortFlag == 0) &&
(sock->state == scclSocketStateAccepting || sock->state == scclSocketStateAccepted));
if(sock->abortFlag && *sock->abortFlag != 0)
return scclInternalError;
switch(sock->state) {
case scclSocketStateAccepting:
case scclSocketStateAccepted:
case scclSocketStateReady: ret = scclSuccess; break;
case scclSocketStateError: ret = scclSystemError; break;
default:
WARN("scclSocketAccept: wrong socket state %d", sock->state);
ret = scclInternalError;
break;
}
exit:
return ret;
}
scclResult_t
scclSocketInit(struct scclSocket* sock, union scclSocketAddress* addr, uint64_t magic, enum scclSocketType type, volatile uint32_t* abortFlag, int asyncFlag) {
scclResult_t ret = scclSuccess;
if(sock == NULL)
goto exit;
sock->timedOutRetries = 0;
sock->refusedRetries = 0;
sock->abortFlag = abortFlag;
sock->asyncFlag = asyncFlag;
sock->state = scclSocketStateInitialized;
sock->magic = magic;
sock->type = type;
sock->fd = -1;
sock->acceptFd = -1;
if(addr) {
/* IPv4/IPv6 support */
int family;
memcpy(&sock->addr, addr, sizeof(union scclSocketAddress));
family = sock->addr.sa.sa_family;
if(family != AF_INET && family != AF_INET6) {
char line[SOCKET_NAME_MAXLEN + 1];
WARN("scclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
scclSocketToString(&sock->addr, line),
family,
AF_INET,
AF_INET6);
ret = scclInternalError;
goto fail;
}
sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
/* Connect to a hostname / port */
sock->fd = socket(family, SOCK_STREAM, 0);
if(sock->fd == -1) {
WARN("scclSocketInit: Socket creation failed : %s", strerror(errno));
ret = scclSystemError;
goto fail;
}
} else {
memset(&sock->addr, 0, sizeof(union scclSocketAddress));
}
/* Set socket as non-blocking if async or if we need to be able to abort */
if((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) {
int flags;
EQCHECKGOTO(flags = fcntl(sock->fd, F_GETFL), -1, ret, fail);
SYSCHECKGOTO(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), ret, fail);
}
exit:
return ret;
fail:
goto exit;
}
scclResult_t scclSocketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
if(sock == NULL) {
WARN("scclSocketProgress: pass NULL socket");
return scclInvalidArgument;
}
SCCLCHECK(socketProgress(op, sock, ptr, size, offset));
return scclSuccess;
}
scclResult_t scclSocketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
if(sock == NULL) {
WARN("scclSocketWait: pass NULL socket");
return scclInvalidArgument;
}
SCCLCHECK(socketWait(op, sock, ptr, size, offset));
return scclSuccess;
}
scclResult_t scclSocketSend(struct scclSocket* sock, void* ptr, int size) {
int offset = 0;
if(sock == NULL) {
WARN("scclSocketSend: pass NULL socket");
return scclInvalidArgument;
}
if(sock->state != scclSocketStateReady) {
WARN("scclSocketSend: socket state (%d) is not ready", sock->state);
return scclInternalError;
}
SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, ptr, size, &offset));
return scclSuccess;
}
scclResult_t scclSocketRecv(struct scclSocket* sock, void* ptr, int size) {
int offset = 0;
if(sock == NULL) {
WARN("scclSocketRecv: pass NULL socket");
return scclInvalidArgument;
}
if(sock->state != scclSocketStateReady) {
WARN("scclSocketRecv: socket state (%d) is not ready", sock->state);
return scclInternalError;
}
SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, ptr, size, &offset));
return scclSuccess;
}
// Receive or detect connection closed
scclResult_t scclSocketTryRecv(struct scclSocket* sock, void* ptr, int size, int* closed, bool blocking) {
int offset = 0;
if(sock == NULL) {
WARN("scclSocketTryRecv: pass NULL socket");
return scclInvalidArgument;
}
*closed = 0;
// Block until connection closes or nbytes received
if(blocking) {
while(offset < size) {
SCCLCHECK(socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed));
if(*closed)
return scclSuccess;
}
} else {
SCCLCHECK(socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed));
if(*closed)
return scclSuccess;
// If any bytes were received, block waiting for the rest
if(offset > 0) {
while(offset < size) {
SCCLCHECK(socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed));
if(*closed)
return scclSuccess;
}
// No bytes were received, return scclInProgress
} else {
return scclInProgress;
}
}
return scclSuccess;
}
scclResult_t scclSocketClose(struct scclSocket* sock) {
if(sock != NULL) {
if(sock->fd >= 0) {
/* shutdown() is needed to send FIN packet to proxy thread; shutdown() is not affected
* by refcount of fd, but close() is. close() won't close a fd and send FIN packet if
* the fd is duplicated (e.g. fork()). So shutdown() guarantees the correct and graceful
* connection close here. */
shutdown(sock->fd, SHUT_RDWR);
close(sock->fd);
}
sock->state = scclSocketStateClosed;
sock->fd = -1;
}
return scclSuccess;
}
scclResult_t scclSocketGetFd(struct scclSocket* sock, int* fd) {
if(sock == NULL) {
WARN("scclSocketGetFd: pass NULL socket");
return scclInvalidArgument;
}
if(fd)
*fd = sock->fd;
return scclSuccess;
}
scclResult_t scclSocketSetFd(int fd, struct scclSocket* sock) {
if(sock == NULL) {
WARN("scclSocketGetFd: pass NULL socket");
return scclInvalidArgument;
}
sock->fd = fd;
return scclSuccess;
}
#pragma once
#include "debug.h"
#include "check.h"
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <fcntl.h>
#include <poll.h>
using namespace sccl;
struct netIf {
char prefix[64];
int port;
};
static thread_local int scclDebugNoWarn = 0;
#define SYSCHECK(call, name) \
do { \
int retval; \
SYSCHECKVAL(call, name, retval); \
} while(false)
#define SYSCHECKVAL(call, name, retval) \
do { \
SYSCHECKSYNC(call, name, retval); \
if(retval == -1) { \
WARN("Call to " name " failed : %s", strerror(errno)); \
return scclSystemError; \
} \
} while(false)
#define SYSCHECKSYNC(call, name, retval) \
do { \
retval = call; \
if(retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
INFO(SCCL_LOG_CODEALL, "Call to " name " returned %s, retrying", strerror(errno)); \
} else { \
break; \
} \
} while(true)
#define EQCHECK(statement, value) \
do { \
if((statement) == value) { \
/* Print the back trace*/ \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, scclSystemError, strerror(errno)); \
return scclSystemError; \
} \
} while(0);
#define NEQCHECKGOTO(statement, value, RES, label) \
do { \
if((statement) != value) { \
/* Print the back trace*/ \
RES = scclSystemError; \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
goto label; \
} \
} while(0);
#define SYSCHECKGOTO(statement, RES, label) \
do { \
if((statement) == -1) { \
/* Print the back trace*/ \
RES = scclSystemError; \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
goto label; \
} \
} while(0);
#define SCCLCHECKGOTO(call, RES, label) \
do { \
RES = call; \
if(RES != scclSuccess && RES != scclInProgress) { \
/* Print the back trace*/ \
if(scclDebugNoWarn == 0) \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d", __FILE__, __LINE__, RES); \
goto label; \
} \
INFO(SCCL_LOG_CODEALL, "check pass %s:%d -> %d", __FILE__, __LINE__, RES); \
} while(0);
#define EQCHECKGOTO(statement, value, RES, label) \
do { \
if((statement) == value) { \
/* Print the back trace*/ \
RES = scclSystemError; \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
goto label; \
} \
} while(0);
static int parseStringList(const char* string, struct netIf* ifList, int maxList) {
if(!string)
return 0;
const char* ptr = string;
int ifNum = 0;
int ifC = 0;
char c;
do {
c = *ptr;
if(c == ':') {
if(ifC > 0) {
ifList[ifNum].prefix[ifC] = '\0';
ifList[ifNum].port = atoi(ptr + 1);
ifNum++;
ifC = 0;
}
while(c != ',' && c != '\0')
c = *(++ptr);
} else if(c == ',' || c == '\0') {
if(ifC > 0) {
ifList[ifNum].prefix[ifC] = '\0';
ifList[ifNum].port = -1;
ifNum++;
ifC = 0;
}
} else {
ifList[ifNum].prefix[ifC] = c;
ifC++;
}
ptr++;
} while(ifNum < maxList && c);
return ifNum;
}
static bool matchIf(const char* string, const char* ref, bool matchExact) {
// Make sure to include '\0' in the exact case
int matchLen = matchExact ? strlen(string) + 1 : strlen(ref);
return strncmp(string, ref, matchLen) == 0;
}
static bool matchPort(const int port1, const int port2) {
if(port1 == -1)
return true;
if(port2 == -1)
return true;
if(port1 == port2)
return true;
return false;
}
static bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact) {
// Make an exception for the case where no user list is defined
if(listSize == 0)
return true;
for(int i = 0; i < listSize; i++) {
if(matchIf(string, ifList[i].prefix, matchExact) && matchPort(port, ifList[i].port)) {
return true;
}
}
return false;
}
#define MAX_IFS 16
#define MAX_IF_NAME_SIZE 16
#define SLEEP_INT 1000 // connection retry sleep interval in usec
#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec)
#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s)
#define SOCKET_NAME_MAXLEN (NI_MAXHOST + NI_MAXSERV)
#define SCCL_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL
union scclSocketAddress {
struct sockaddr sa;
struct sockaddr_in sin;
struct sockaddr_in6 sin6;
};
enum scclSocketState {
scclSocketStateNone = 0,
scclSocketStateInitialized = 1,
scclSocketStateAccepting = 2,
scclSocketStateAccepted = 3,
scclSocketStateConnecting = 4,
scclSocketStateConnectPolling = 5,
scclSocketStateConnected = 6,
scclSocketStateReady = 7,
scclSocketStateClosed = 8,
scclSocketStateError = 9,
scclSocketStateNum = 10
};
enum scclSocketType {
scclSocketTypeUnknown = 0,
scclSocketTypeBootstrap = 1,
scclSocketTypeProxy = 2,
scclSocketTypeNetSocket = 3,
scclSocketTypeNetIb = 4
};
struct scclSocket {
int fd;
int acceptFd;
int timedOutRetries;
int refusedRetries;
union scclSocketAddress addr;
volatile uint32_t* abortFlag;
int asyncFlag;
enum scclSocketState state;
int salen;
uint64_t magic;
enum scclSocketType type;
};
const char* scclSocketToString(union scclSocketAddress* addr, char* buf, const int numericHostForm = 1);
scclResult_t scclSocketGetAddrFromString(union scclSocketAddress* ua, const char* ip_port_pair);
int scclFindInterfaceMatchSubnet(char* ifNames, union scclSocketAddress* localAddrs, union scclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs);
int scclFindInterfaces(char* ifNames, union scclSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs);
// Initialize a socket
scclResult_t scclSocketInit(struct scclSocket* sock,
union scclSocketAddress* addr = NULL,
uint64_t magic = SCCL_SOCKET_MAGIC,
enum scclSocketType type = scclSocketTypeUnknown,
volatile uint32_t* abortFlag = NULL,
int asyncFlag = 0);
// Create a listening socket. sock->addr can be pre-filled with IP & port info. sock->fd is set after a successful call
scclResult_t scclSocketListen(struct scclSocket* sock);
scclResult_t scclSocketGetAddr(struct scclSocket* sock, union scclSocketAddress* addr);
// Connect to sock->addr. sock->fd is set after a successful call.
scclResult_t scclSocketConnect(struct scclSocket* sock, int portReuse = 0);
// Return socket connection state.
scclResult_t scclSocketReady(struct scclSocket* sock, int* running);
// Accept an incoming connection from listenSock->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr.
scclResult_t scclSocketAccept(struct scclSocket* sock, struct scclSocket* ulistenSock);
scclResult_t scclSocketGetFd(struct scclSocket* sock, int* fd);
scclResult_t scclSocketSetFd(int fd, struct scclSocket* sock);
#define SCCL_SOCKET_SEND 0
#define SCCL_SOCKET_RECV 1
scclResult_t scclSocketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset);
scclResult_t scclSocketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset);
scclResult_t scclSocketSend(struct scclSocket* sock, void* ptr, int size);
scclResult_t scclSocketRecv(struct scclSocket* sock, void* ptr, int size);
scclResult_t scclSocketTryRecv(struct scclSocket* sock, void* ptr, int size, int* closed, bool blocking);
scclResult_t scclSocketClose(struct scclSocket* sock);
#include "socket.h"
#include "debug.h"
#include "check.h"
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <fcntl.h>
#include <poll.h>
using namespace sccl;
#define MAX_REQUESTS 8
#define MAX_THREADS 16
#define MAX_SOCKETS 64
struct scclNetSocketTask {
int op;
void* data;
int size;
struct scclSocket* sock;
int offset;
int used;
scclResult_t result;
};
struct scclNetSocketTaskQueue {
int next;
int len;
struct scclNetSocketTask* tasks;
};
struct scclNetSocketRequest {
int op;
void* data;
int size;
struct scclSocket* ctrlSock;
int offset;
int used;
struct scclNetSocketComm* comm;
struct scclNetSocketTask* tasks[MAX_SOCKETS];
int nSubs;
};
struct scclNetSocketThreadResources {
struct scclNetSocketTaskQueue threadTaskQueue;
int stop;
struct scclNetSocketComm* comm;
pthread_mutex_t threadLock;
pthread_cond_t threadCond;
};
struct scclNetSocketComm {
struct scclSocket ctrlSock;
struct scclSocket socks[MAX_SOCKETS];
int dev;
int hipDev;
int nSocks;
int nThreads;
int nextSock;
struct scclNetSocketRequest requests[MAX_REQUESTS];
pthread_t helperThread[MAX_THREADS];
struct scclNetSocketThreadResources threadResources[MAX_THREADS];
};
#define DIVUP(x, y) (((x) + (y) - 1) / (y))
#define MIN_CHUNKSIZE (64 * 1024)
template <typename T>
scclResult_t scclCallocDebug(T** ptr, size_t nelem, const char* filefunc, int line) {
void* p = malloc(nelem * sizeof(T));
if(p == NULL) {
WARN("Failed to malloc %ld bytes", nelem * sizeof(T));
return scclSystemError;
}
memset(p, 0, nelem * sizeof(T));
*ptr = (T*)p;
return scclSuccess;
}
#define scclCalloc(...) scclCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
void scclSetThreadName(pthread_t thread, const char* fmt, ...) {
#ifdef _GNU_SOURCE
char threadName[16];
va_list vargs;
va_start(vargs, fmt);
vsnprintf(threadName, 16, fmt, vargs);
va_end(vargs);
pthread_setname_np(thread, threadName);
#endif
}
void* persistentSocketThread(void* args_) {
struct scclNetSocketThreadResources* resource = (struct scclNetSocketThreadResources*)args_;
struct scclNetSocketComm* comm = resource->comm;
struct scclNetSocketTaskQueue* myQueue = &resource->threadTaskQueue;
int nSocksPerThread = comm->nSocks / comm->nThreads;
while(1) {
int idle = 1;
int mark = myQueue->next; // mark newest task seen
for(int i = 0; i < myQueue->len; i += nSocksPerThread) {
int repeat;
do {
repeat = 0;
for(int j = 0; j < nSocksPerThread; j++) {
struct scclNetSocketTask* r = myQueue->tasks + i + j;
if(r != NULL && r->used == 1 && r->offset < r->size) {
r->result = scclSocketProgress(r->op, r->sock, r->data, r->size, &r->offset);
if(r->result != scclSuccess) {
WARN("NET/Socket : socket progress error");
return NULL;
}
idle = 0;
if(r->offset < r->size)
repeat = 1;
}
}
} while(repeat);
}
if(idle) {
pthread_mutex_lock(&resource->threadLock);
while(mark == myQueue->next && resource->stop == 0) { // no new tasks, wait
pthread_cond_wait(&resource->threadCond, &resource->threadLock);
}
pthread_mutex_unlock(&resource->threadLock);
}
if(resource->stop)
return NULL;
}
}
scclResult_t scclNetSocketGetTask(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketTask** req) {
int tid = comm->nextSock % comm->nThreads;
struct scclNetSocketThreadResources* res = comm->threadResources + tid;
struct scclNetSocketTaskQueue* queue = &res->threadTaskQueue;
// create helper threads and prepare per-thread task queue
if(queue->tasks == NULL) {
// each request can be divided up to nSocks tasks, and
// these tasks are distributed to nThreads threads,
// we need to make sure each thread queue has enough slots for MAX_REQUESTS
queue->len = MAX_REQUESTS * DIVUP(comm->nSocks, comm->nThreads);
SCCLCHECK(scclCalloc(&queue->tasks, queue->len));
queue->next = 0;
res->comm = comm;
pthread_mutex_init(&res->threadLock, NULL);
pthread_cond_init(&res->threadCond, NULL);
pthread_create(comm->helperThread + tid, NULL, persistentSocketThread, res);
scclSetThreadName(comm->helperThread[tid], "NCCL Sock%c%1u%2u%2u", op == SCCL_SOCKET_SEND ? 'S' : 'R', comm->dev, tid, comm->hipDev);
}
struct scclNetSocketTask* r = queue->tasks + queue->next;
if(r->used == 0) {
r->op = op;
r->data = data;
r->size = size;
r->sock = comm->socks + comm->nextSock;
r->offset = 0;
r->result = scclSuccess;
comm->nextSock = (comm->nextSock + 1) % comm->nSocks;
r->used = 1;
*req = r;
pthread_mutex_lock(&res->threadLock);
queue->next = (queue->next + 1) % queue->len;
pthread_cond_signal(&res->threadCond);
pthread_mutex_unlock(&res->threadLock);
return scclSuccess;
}
WARN("NET/Socket : unable to allocate subtasks");
return scclInternalError;
}
/**
* @brief 测试socket通信请求状态
*
* 该函数用于测试socket通信请求的完成状态,并处理数据传输过程。它会根据请求的不同状态(未开始、正在交换数据大小、已完成交换)执行相应的操作:
* - 如果请求未开始(used=0),则初始化状态
* - 如果正在交换数据大小(used=1),则处理数据大小交换逻辑
* - 如果已完成数据大小交换(used=2),则处理实际数据传输
*
* @param request 指向socket请求的指针
* @param done 输出参数,指示请求是否完成(1=完成,0=未完成)
* @param size 输出参数,返回传输的数据大小
* @return scclResult_t 返回操作结果状态码
*/
scclResult_t scclNetSocketTest(void* request, int* done, int* size) {
*done = 0;
struct scclNetSocketRequest* r = (struct scclNetSocketRequest*)request;
if(r == NULL) {
INFO(SCCL_LOG_CODEALL, "NET/Socket : test called with NULL request");
return scclInternalError;
}
INFO(SCCL_LOG_CODEALL, "NET/Socket : test called request used:%d\n", r->used);
if(r->used == 1) { /* try to send/recv size */
int data = r->size;
int offset = 0;
SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, &data, sizeof(int), &offset));
if(offset == 0)
return scclSuccess; /* Not ready -- retry later */
// Not sure we could ever receive less than 4 bytes, but just in case ...
if(offset < sizeof(int))
SCCLCHECK(scclSocketWait(r->op, r->ctrlSock, &data, sizeof(int), &offset));
// Check size is less or equal to the size provided by the user
if(r->op == SCCL_SOCKET_RECV && data > r->size) {
char line[SOCKET_NAME_MAXLEN + 1];
union scclSocketAddress addr;
scclSocketGetAddr(r->ctrlSock, &addr);
WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d. If you believe your socket network is in healthy state, \
there may be a mismatch in collective sizes or environment settings (e.g. SCCL_PROTO, SCCL_ALGO) between ranks",
scclSocketToString(&addr, line),
data,
r->size);
return scclInvalidUsage;
}
r->size = data;
r->offset = 0;
r->used = 2; // done exchanging size
// divide into subtasks
int chunkOffset = 0, i = 0;
if(r->comm->nSocks > 0) {
// each request can be divided up to nSocks tasks
int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks));
while(chunkOffset < r->size) {
int chunkSize = std::min(taskSize, r->size - chunkOffset);
SCCLCHECK(scclNetSocketGetTask(r->comm, r->op, (char*)(r->data) + chunkOffset, chunkSize, r->tasks + i++));
chunkOffset += chunkSize;
}
}
r->nSubs = i;
}
if(r->used == 2) { // already exchanged size
if(r->nSubs > 0) {
int nCompleted = 0;
for(int i = 0; i < r->nSubs; i++) {
struct scclNetSocketTask* sub = r->tasks[i];
if(sub->result != scclSuccess)
return sub->result;
if(sub->offset == sub->size)
nCompleted++;
}
if(nCompleted == r->nSubs) {
if(size)
*size = r->size;
*done = 1;
r->used = 0;
for(int i = 0; i < r->nSubs; i++) {
struct scclNetSocketTask* sub = r->tasks[i];
sub->used = 0;
}
}
} else { // progress request using main thread
if(r->offset < r->size) {
SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, r->data, r->size, &r->offset));
}
if(r->offset == r->size) {
if(size)
*size = r->size;
*done = 1;
r->used = 0;
}
}
}
return scclSuccess;
}
int main(int argc, char* argv[]) {
struct scclNetSocketRequest* request = (struct scclNetSocketRequest*)malloc(sizeof(struct scclNetSocketRequest));
request->op = SCCL_SOCKET_SEND;
request->used = 1;
request->size = 1024;
request->data = (char*)malloc(request->size);
request->ctrlSock = NULL;
request->comm = NULL;
request->nSubs = 0;
int done;
int sizes[32];
printf("test\n");
INFO(SCCL_LOG_CODEALL, "test INFO");
SCCLCHECK(scclSocketInit(request));
SCCLCHECK(scclNetSocketTest(request, &done, sizes));
if(done) {
printf("done\n");
}
}
\ No newline at end of file
hipcc ./test_topo.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo/topo.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo/xml.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo/rocm_smi_wrap.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/utils.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/archinfo.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/nvmlwrap.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvsymbols.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ibvwrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/net_ib.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/net_socket.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/net_utils.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/rocm_wrap.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/param.cpp \
-o test_topo \
-std=c++17 -g -O3 -fopenmp -D__HIP_PLATFORM_HCC__ \
-I ./ -I /usr/include -I /opt/dtk/include \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/include \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/device/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/host/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/net/ \
-L /public/home/lishen/Code/rocSHMEM/SCCL_v1 \
-L /usr/lib/x86_64-linux-gnu -L /usr/lib/ \
-libverbs -lrdmacm -lamdhip64 -lrocm_smi64
hipcc ./test_xml.cpp \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo/topo.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo/xml.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo/rocm_smi_wrap.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/utils.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/archinfo.cc \
/public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/nvmlwrap.cc \
-o test_xml \
-std=c++17 -g -O3 -fopenmp -D__HIP_PLATFORM_HCC__ \
-I ./ -I /usr/include -I /opt/dtk/include \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/include \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/ \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/hardware/topology/topo \
-I /public/home/lishen/Code/rocSHMEM/SCCL_v1/src/utils/ \
-L /usr/lib/x86_64-linux-gnu \
-L /usr/lib/ \
-lamdhip64 -lrocm_smi64
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment