Commit d9d23f34 authored by lishen

Initial Code for SCCL_v1

parent 57df3737
/**
* MIT License
*
* Copyright 2019-2020 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
/*!\file
* \brief sccl_bfloat16.h provides struct for sccl_bfloat16 typedef
*/
#ifndef _SCCL_BFLOAT16_H_
#define _SCCL_BFLOAT16_H_
#if __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__) && !defined(__HIP_PLATFORM_HCC__))
// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
// include a minimal definition of sccl_bfloat16
#include <stdint.h>
/*! \brief Struct to represent a 16 bit brain floating point number. */
#ifdef __cplusplus
namespace sccl {
#endif
typedef struct {
uint16_t data;
} sccl_bfloat16;
#ifdef __cplusplus
} // namespace sccl
#endif
#else // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__) && !defined(__HIP_PLATFORM_HCC__))
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <hip/hip_runtime.h>
#include <ostream>
#include <type_traits>
namespace sccl {
struct sccl_bfloat16 {
uint16_t data;
enum truncate_t {
truncate
};
__host__ __device__ sccl_bfloat16() = default;
// round upper 16 bits of IEEE float to convert to bfloat16
explicit __host__ __device__ sccl_bfloat16(float f) : data(float_to_bfloat16(f)) {}
explicit __host__ __device__ sccl_bfloat16(float f, truncate_t) : data(truncate_float_to_bfloat16(f)) {}
// zero extend lower 16 bits of bfloat16 to convert to IEEE float
__host__ __device__ operator float() const {
union {
uint32_t int32;
float fp32;
} u = {uint32_t(data) << 16};
return u.fp32;
}
private:
static __host__ __device__ uint16_t float_to_bfloat16(float f) {
union {
float fp32;
uint32_t int32;
} u = {f};
if(~u.int32 & 0x7f800000) {
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
} else if(u.int32 & 0xffff) {
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
u.int32 |= 0x10000; // Preserve signaling NaN
}
return uint16_t(u.int32 >> 16);
}
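// Worked example of the round-to-nearest-even logic above (a sketch, not part
// of the API): 1.00390625f is 0x3F808000, so the discarded low 16 bits are
// exactly 0x8000 (a tie) and the bfloat16 mantissa LSB is 0 (even) -> no
// increment, result 0x3F80 == 1.0f. 1.01171875f is 0x3F818000: same tie, but
// the LSB is 1 (odd) -> increment, result 0x3F82 == 1.015625f.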
// Truncate instead of rounding, preserving SNaN
static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f) {
union {
float fp32;
uint32_t int32;
} u = {f};
return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
}
};
typedef struct {
uint16_t data;
} sccl_bfloat16_public;
static_assert(std::is_standard_layout<sccl_bfloat16>{},
"sccl_bfloat16 is not a standard layout type, and thus is "
"incompatible with C.");
static_assert(std::is_trivial<sccl_bfloat16>{},
"sccl_bfloat16 is not a trivial type, and thus is "
"incompatible with C.");
static_assert(sizeof(sccl_bfloat16) == sizeof(sccl_bfloat16_public) && offsetof(sccl_bfloat16, data) == offsetof(sccl_bfloat16_public, data),
"internal sccl_bfloat16 does not match public sccl_bfloat16");
inline std::ostream& operator<<(std::ostream& os, const sccl_bfloat16& bf16) { return os << float(bf16); }
inline __host__ __device__ sccl_bfloat16 operator+(sccl_bfloat16 a) { return a; }
inline __host__ __device__ sccl_bfloat16 operator-(sccl_bfloat16 a) {
a.data ^= 0x8000;
return a;
}
inline __host__ __device__ sccl_bfloat16 operator+(sccl_bfloat16 a, sccl_bfloat16 b) { return sccl_bfloat16(float(a) + float(b)); }
inline __host__ __device__ sccl_bfloat16 operator-(sccl_bfloat16 a, sccl_bfloat16 b) { return sccl_bfloat16(float(a) - float(b)); }
inline __host__ __device__ sccl_bfloat16 operator*(sccl_bfloat16 a, sccl_bfloat16 b) { return sccl_bfloat16(float(a) * float(b)); }
inline __host__ __device__ sccl_bfloat16 operator/(sccl_bfloat16 a, sccl_bfloat16 b) { return sccl_bfloat16(float(a) / float(b)); }
inline __host__ __device__ bool operator<(sccl_bfloat16 a, sccl_bfloat16 b) { return float(a) < float(b); }
inline __host__ __device__ bool operator==(sccl_bfloat16 a, sccl_bfloat16 b) { return float(a) == float(b); }
inline __host__ __device__ bool operator>(sccl_bfloat16 a, sccl_bfloat16 b) { return b < a; }
inline __host__ __device__ bool operator<=(sccl_bfloat16 a, sccl_bfloat16 b) { return !(a > b); }
inline __host__ __device__ bool operator!=(sccl_bfloat16 a, sccl_bfloat16 b) { return !(a == b); }
inline __host__ __device__ bool operator>=(sccl_bfloat16 a, sccl_bfloat16 b) { return !(a < b); }
inline __host__ __device__ sccl_bfloat16& operator+=(sccl_bfloat16& a, sccl_bfloat16 b) { return a = a + b; }
inline __host__ __device__ sccl_bfloat16& operator-=(sccl_bfloat16& a, sccl_bfloat16 b) { return a = a - b; }
inline __host__ __device__ sccl_bfloat16& operator*=(sccl_bfloat16& a, sccl_bfloat16 b) { return a = a * b; }
inline __host__ __device__ sccl_bfloat16& operator/=(sccl_bfloat16& a, sccl_bfloat16 b) { return a = a / b; }
inline __host__ __device__ sccl_bfloat16& operator++(sccl_bfloat16& a) { return a += sccl_bfloat16(1.0f); }
inline __host__ __device__ sccl_bfloat16& operator--(sccl_bfloat16& a) { return a -= sccl_bfloat16(1.0f); }
inline __host__ __device__ sccl_bfloat16 operator++(sccl_bfloat16& a, int) {
sccl_bfloat16 orig = a;
++a;
return orig;
}
inline __host__ __device__ sccl_bfloat16 operator--(sccl_bfloat16& a, int) {
sccl_bfloat16 orig = a;
--a;
return orig;
}
namespace std {
constexpr __host__ __device__ bool isinf(sccl_bfloat16 a) { return !(~a.data & 0x7f80) && !(a.data & 0x7f); }
constexpr __host__ __device__ bool isnan(sccl_bfloat16 a) { return !(~a.data & 0x7f80) && +(a.data & 0x7f); }
constexpr __host__ __device__ bool iszero(sccl_bfloat16 a) { return !(a.data & 0x7fff); }
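// Bit-pattern reference for the predicates above: the exponent field is bits
// 14-7 (mask 0x7f80) and the mantissa is bits 6-0 (mask 0x7f). So 0x7f80 and
// 0xff80 are +/-Inf, 0x7fc0 is a quiet NaN, and 0x0000/0x8000 are +/-0.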
inline sccl_bfloat16 sin(sccl_bfloat16 a) { return sccl_bfloat16(sinf(float(a))); }
inline sccl_bfloat16 cos(sccl_bfloat16 a) { return sccl_bfloat16(cosf(float(a))); }
} // namespace std
} // namespace sccl
#endif // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__) && !defined(__HIP_PLATFORM_HCC__))
#endif // _SCCL_BFLOAT16_H_
#include "core.h"
#include "graph.h"
#include "topo.h"
#include "xml.h"
#include <math.h>
#include <sys/time.h>
#include "rome_models.h"
namespace sccl {
namespace hardware {
namespace topology {
namespace detect {
SCCL_PARAM(CrossNic, "CROSS_NIC", 2);
// Initialize system->maxBw. This is the per-channel (i.e. per-SM)
// max bw.
static float getMaxBw(struct scclTopoSystem* system, struct scclTopoNode* gpu, int type) {
float maxBw = 0.0;
for(int i = 0; i < system->nodes[type].count; i++) {
struct scclTopoLinkList* path = gpu->paths[type] + i;
float bw = path->bw;
if(path->count == 0)
continue;
maxBw = std::max(maxBw, bw);
}
return maxBw;
}
static float getTotalBw(struct scclTopoSystem* system, struct scclTopoNode* gpu) {
float nvlinkBw = 0.0, pciBw = 0.0;
for(int l = 0; l < gpu->nlinks; l++) {
struct scclTopoLink* link = gpu->links + l;
if(link->type == LINK_NVL)
nvlinkBw += link->bw;
if(link->type == LINK_PCI)
pciBw = link->bw;
}
return std::max(pciBw, nvlinkBw);
}
scclResult_t scclTopoSearchInit(struct scclTopoSystem* system) {
system->maxBw = 0.0;
system->totalBw = 0.0;
int inter = system->nodes[NET].count;
if(inter == 0 && system->nodes[GPU].count == 1) {
system->maxBw = LOC_BW;
return scclSuccess;
}
for(int g = 0; g < system->nodes[GPU].count; g++) {
struct scclTopoNode* gpu = system->nodes[GPU].nodes + g;
system->maxBw = std::max(system->maxBw, getMaxBw(system, gpu, inter ? NET : GPU));
system->totalBw = std::max(system->totalBw, getTotalBw(system, gpu));
}
return scclSuccess;
}
static scclResult_t findRevLink(struct scclTopoNode* node1, struct scclTopoNode* node2, struct scclTopoLink** revLink) {
for(int l = 0; l < node2->nlinks; l++) {
struct scclTopoLink* link = node2->links + l;
if(link->remNode == node1) {
*revLink = link;
return scclSuccess;
}
}
WARN("Could not find rev link for %d/%ld -> %d/%ld", node1->type, node1->id, node2->type, node2->id);
return scclInternalError;
}
// This is unfortunately needed since manipulating floats often results in rounding errors.
#define SUB_ROUND(a, b) (a = roundf((a - b) * 1000) / 1000)
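// Example of the problem SUB_ROUND avoids: 2.4 is not exactly representable in
// binary floating point, so subtracting it from 24.0 ten times leaves a tiny
// nonzero residue instead of 0, which would make later "link->bw < fwBw"
// checks flaky. Rounding to 1/1000 after each subtraction keeps the bandwidth
// bookkeeping exact.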
static scclResult_t followPath(struct scclTopoLinkList* path, struct scclTopoNode* start, int maxSteps, float bw, int* steps) {
float pciBw = bw;
for(int step = 0; step < path->count; step++) {
struct scclTopoNode* node = path->list[step]->remNode;
if(node->type == CPU) {
// Account for P2P inefficiency through Intel CPU RC
if(path->type == PATH_PHB && start->type == GPU && node->cpu.arch == SCCL_TOPO_CPU_ARCH_X86 && node->cpu.vendor == SCCL_TOPO_CPU_VENDOR_INTEL) {
pciBw = INTEL_P2P_OVERHEAD(bw);
}
}
}
struct scclTopoNode* node = start;
for(int step = 0; step < maxSteps; step++) {
struct scclTopoLink* link = path->list[step];
struct scclTopoLink* revLink = NULL;
float fwBw = link->type == LINK_PCI ? pciBw : bw;
float revBw = 0;
if(link->remNode->type == GPU && link->remNode->gpu.cudaCompCap < 80 && start->type != GPU) {
if(revLink == NULL)
SCCLCHECK(findRevLink(node, link->remNode, &revLink));
revBw += fwBw / 8;
}
if(link->remNode->type == CPU && link->type == LINK_NVL) {
if(revLink == NULL)
SCCLCHECK(findRevLink(node, link->remNode, &revLink));
revBw += fwBw;
}
if(link->bw < fwBw || (revBw && revLink->bw < revBw)) {
*steps = step;
return scclSuccess;
}
SUB_ROUND(link->bw, fwBw);
if(revBw)
SUB_ROUND(revLink->bw, revBw);
node = link->remNode;
}
*steps = maxSteps;
return scclSuccess;
}
// Try to go from node type1/index1 to node type2/index2. mult indicates whether we are counting the bandwidth (1) or undoing it (-1).
static scclResult_t scclTopoFollowPath(
struct scclTopoSystem* system, struct scclTopoGraph* graph, int type1, int index1, int type2, int index2, int mult, struct scclTopoNode** node) {
// First handle easy cases
*node = system->nodes[type2].nodes + index2;
if(type1 == -1)
return scclSuccess;
struct scclTopoNode* node1 = system->nodes[type1].nodes + index1;
struct scclTopoLinkList* path = node1->paths[type2] + index2;
struct scclTopoNode* node2 = system->nodes[type2].nodes + index2;
struct scclTopoLinkList* revPath = node2->paths[type1] + index1;
if(path == NULL) {
WARN("No path computed to go from %s/%d to %s/%d", topoNodeTypeStr[type1], index1, topoNodeTypeStr[type2], index2);
return scclInternalError;
}
if(path->count == 0)
return scclSuccess;
// Now check link type
*node = NULL;
int intra = (type1 == GPU || type1 == NVS) && (type2 == GPU || type2 == NVS);
float bw = intra ? graph->bwIntra : graph->bwInter;
int type = intra ? graph->typeIntra : graph->typeInter;
if(mult == 1 && (path->type > type))
return scclSuccess;
if(mult == 1 &&
(graph->pattern == SCCL_TOPO_PATTERN_BALANCED_TREE || graph->pattern == SCCL_TOPO_PATTERN_TREE || graph->pattern == SCCL_TOPO_PATTERN_SPLIT_TREE) &&
(revPath->type > type))
return scclSuccess;
bw *= mult;
// Check there is enough bandwidth on paths.
int step = 0;
SCCLCHECK(followPath(path, node1, path->count, bw, &step));
if(step < path->count)
goto rewind;
// Enough bandwidth : return destination node.
graph->nHops += mult * path->count;
*node = system->nodes[type2].nodes + index2;
return scclSuccess;
rewind:
// Not enough bandwidth : rewind and exit.
SCCLCHECK(followPath(path, node1, step, -bw, &step));
return scclSuccess;
}
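// Typical calling pattern (see scclTopoSearchTryGpu below): reserve bandwidth
// with mult=1, recurse, then release it with mult=-1 so the search can backtrack.
//   SCCLCHECK(scclTopoFollowPath(system, graph, type, index, GPU, g, 1, &gpu));
//   if(gpu) { /* ... recurse ... */ SCCLCHECK(scclTopoFollowPath(system, graph, type, index, GPU, g, -1, &gpu)); }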
static int gpuPciBw(struct scclTopoNode* gpu) {
for(int l = 0; l < gpu->nlinks; l++) {
struct scclTopoLink* gpuLink = gpu->links + l;
if(gpuLink->type != LINK_PCI)
continue;
struct scclTopoNode* pci = gpuLink->remNode;
for(int l = 0; l < pci->nlinks; l++) {
struct scclTopoLink* pciLink = pci->links + l;
if(pciLink->remNode != gpu)
continue;
return std::min(gpuLink->bw, pciLink->bw);
}
}
return -1;
}
/* Choose the order in which we try next GPUs. This is critical for the search
to quickly converge to the best solution even if it eventually times out. */
struct scclGpuScore {
int g; // Retain the index
int startIndex; // Least important
int intraNhops;
int intraBw;
int interNhops;
int interPciBw;
int interBw; // Most important
};
static int cmpScore(const void* g1, const void* g2) {
struct scclGpuScore* s1 = (struct scclGpuScore*)g1;
struct scclGpuScore* s2 = (struct scclGpuScore*)g2;
int d;
if((d = (s2->interBw - s1->interBw)))
return d;
if((d = (s2->interPciBw - s1->interPciBw)))
return d;
if((d = (s1->interNhops - s2->interNhops)))
return d;
if((d = (s2->startIndex - s1->startIndex)))
return d;
if((d = (s2->intraBw - s1->intraBw)))
return d;
if((d = (s1->intraNhops - s2->intraNhops)))
return d;
return s1->startIndex - s2->startIndex;
}
static int cmpIntraScores(struct scclGpuScore* scores, int count) {
int intraBw = scores[0].intraBw;
int intraNhops = scores[0].intraNhops;
for(int i = 1; i < count; i++) {
if(scores[i].intraBw != intraBw || scores[i].intraNhops != intraNhops)
return 1;
}
return 0;
}
static scclResult_t getGpuIndex(struct scclTopoSystem* system, int rank, int* index) {
for(int g = 0; g < system->nodes[GPU].count; g++) {
if(system->nodes[GPU].nodes[g].gpu.rank == rank) {
*index = g;
return scclSuccess;
}
}
WARN("Could not find gpu rank %d", rank);
return scclInternalError;
}
static scclResult_t getNetIndex(struct scclTopoSystem* system, int64_t id, int* index) {
for(int n = 0; n < system->nodes[NET].count; n++) {
if(system->nodes[NET].nodes[n].id == id) {
*index = n;
return scclSuccess;
}
}
WARN("Could not find net id %lx", id);
return scclInternalError;
}
static scclResult_t getNetPaths(struct scclTopoSystem* system, struct scclTopoGraph* graph, struct scclTopoLinkList** netPaths) {
int netId = graph->inter[graph->nChannels * 2];
int n;
SCCLCHECK(getNetIndex(system, netId, &n));
*netPaths = system->nodes[NET].nodes[n].paths[GPU];
return scclSuccess;
}
scclResult_t
scclTopoSearchNextGpuSort(struct scclTopoSystem* system, struct scclTopoGraph* graph, struct scclTopoNode* gpu, int* next, int* countPtr, int sortNet) {
const uint64_t flag = 1ULL << (graph->nChannels);
int ngpus = system->nodes[GPU].count;
struct scclTopoLinkList* paths = gpu->paths[GPU];
struct scclTopoLinkList* netPaths = NULL;
if(sortNet)
SCCLCHECK(getNetPaths(system, graph, &netPaths));
struct scclGpuScore scores[SCCL_TOPO_MAX_NODES];
memset(scores, 0, ngpus * sizeof(struct scclGpuScore));
int start = gpu - system->nodes[GPU].nodes;
int count = 0;
for(int i = 1; i < ngpus; i++) {
int g = (start + i) % ngpus;
if(paths[g].count == 0)
continue; // There is no path to that GPU
if(system->nodes[GPU].nodes[g].used & flag)
continue;
scores[count].g = g;
scores[count].startIndex = i;
scores[count].intraNhops = paths[g].count;
scores[count].intraBw = paths[g].bw;
if(netPaths) {
scores[count].interNhops = netPaths[g].count;
scores[count].interPciBw = gpuPciBw(system->nodes[GPU].nodes + g);
scores[count].interBw = netPaths[g].bw;
}
count++;
}
// Sort GPUs
qsort(scores, count, sizeof(struct scclGpuScore), cmpScore);
// Check if all have the same intra-node score in which case we go reverse for sortNet = -1
if(sortNet == -1 && cmpIntraScores(scores, count) == 0) {
for(int i = 0; i < count; i++)
next[i] = scores[count - 1 - i].g;
} else {
for(int i = 0; i < count; i++)
next[i] = scores[i].g;
}
*countPtr = count;
return scclSuccess;
}
scclResult_t scclTopoSearchRec(struct scclTopoSystem* system, struct scclTopoGraph* graph, struct scclTopoGraph* saveGraph, int* time);
// Try to keep all searches within one second
#define SCCL_SEARCH_GLOBAL_TIMEOUT (5ULL << 16)
#define SCCL_SEARCH_TIMEOUT (1 << 14)
#define SCCL_SEARCH_TIMEOUT_TREE (1 << 14)
#define SCCL_SEARCH_TIMEOUT_SAMECHANNELS (1 << 8)
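// Budget arithmetic: the global budget of 5ULL << 16 = 327680 steps covers
// twenty full searches of SCCL_SEARCH_TIMEOUT = 1 << 14 = 16384 steps each
// (same-channel searches are far cheaper at 1 << 8 = 256 steps).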
#define FORCED_ORDER_PCI 1
#define FORCED_ORDER_REPLAY 2
scclResult_t scclTopoReplayGetGpu(struct scclTopoSystem* system, struct scclTopoGraph* graph, int step, int* g) {
*g = -1;
if(graph->nChannels == 0)
return scclInternalError;
int ngpus = system->nodes[GPU].count;
int nextRank = graph->intra[(graph->nChannels - 1) * ngpus + step + 1];
for(int i = 0; i < ngpus; i++)
if(system->nodes[GPU].nodes[i].gpu.rank == nextRank) {
*g = i;
return scclSuccess;
}
if(*g == -1)
return scclInternalError;
return scclSuccess;
}
scclResult_t scclTopoSearchRecGpu(struct scclTopoSystem* system,
struct scclTopoGraph* graph,
struct scclTopoGraph* saveGraph,
struct scclTopoNode* gpu,
int step,
int backToNet,
int backToFirstRank,
int forcedOrder,
int* time);
scclResult_t scclTopoSearchTryGpu(struct scclTopoSystem* system,
struct scclTopoGraph* graph,
struct scclTopoGraph* saveGraph,
int step,
int backToNet,
int backToFirstRank,
int forcedOrder,
int* time,
int type,
int index,
int g) {
const uint64_t flag = 1ULL << (graph->nChannels);
struct scclTopoNode* gpu;
SCCLCHECK(scclTopoFollowPath(system, graph, type, index, GPU, g, 1, &gpu));
if(gpu) {
gpu->used ^= flag;
SCCLCHECK(scclTopoSearchRecGpu(system, graph, saveGraph, gpu, step, backToNet, backToFirstRank, forcedOrder, time));
gpu->used ^= flag;
SCCLCHECK(scclTopoFollowPath(system, graph, type, index, GPU, g, -1, &gpu));
}
return scclSuccess;
}
static int scclTopoCountXGMI(struct scclTopoSystem* system, struct scclTopoGraph* graph) {
int ngpus = system->nodes[GPU].count;
int count = 0;
for(int c = 0; c < graph->nChannels; c++) {
for(int i = 0; i < ngpus; i++) {
int g = graph->intra[ngpus * c + i];
int n = graph->intra[ngpus * c + ((i + 1) % ngpus)];
struct scclTopoNode* node;
int j;
for(j = 0; j < ngpus; j++)
if(system->nodes[GPU].nodes[j].gpu.rank == g)
break;
if(j < ngpus) {
node = system->nodes[GPU].nodes + j;
for(int k = 0; k < system->nodes[GPU].count; k++) {
if(node->paths[GPU][k].count == 1) {
struct scclTopoLink* link = node->paths[GPU][k].list[0];
struct scclTopoNode* remNode = link->remNode;
if(remNode->gpu.rank == n) {
if(link->type == LINK_NVL)
count++;
}
}
}
}
}
}
return count;
}
scclResult_t scclTopoSearchTryNvls(struct scclTopoSystem* system, struct scclTopoGraph* graph, struct scclTopoGraph* saveGraph, int g, int ngpus, int* time) {
struct scclTopoNode* nvs;
struct scclTopoNode* gpu;
int d0 = 0; // See if there is enough bandwidth for NVS->GPU traffic
do {
SCCLCHECK(scclTopoFollowPath(system, graph, NVS, 0, GPU, d0, d0 == g ? 2 : 1, &gpu));
d0++;
} while(gpu && d0 < system->nodes[GPU].count);
if(gpu == NULL) {
d0--;
} else {
int d1 = 0; // See if there is enough bandwidth for GPU->NVS traffic
do {
SCCLCHECK(scclTopoFollowPath(system, graph, GPU, d1, NVS, 0, d1 == g ? 2 : 1, &nvs));
d1++;
} while(nvs && d1 < system->nodes[GPU].count);
if(nvs == NULL) {
d1--;
} else { // Both directions worked. Move on to the next path.
SCCLCHECK(scclTopoSearchRecGpu(system, graph, saveGraph, NULL, ngpus, -1, -1, 0, time));
}
while(d1) {
d1--;
SCCLCHECK(scclTopoFollowPath(system, graph, GPU, d1, NVS, 0, d1 == g ? -2 : -1, &nvs));
}
}
while(d0) {
d0--;
SCCLCHECK(scclTopoFollowPath(system, graph, NVS, 0, GPU, d0, d0 == g ? -2 : -1, &gpu));
}
return scclSuccess;
}
scclResult_t scclTopoCompareGraphs(struct scclTopoSystem* system, struct scclTopoGraph* graph, struct scclTopoGraph* refGraph, int* copy) {
// 1. Try to get the same nChannels between Rings and Trees
if(graph->nChannels < graph->minChannels)
return scclSuccess;
if(graph->pattern == SCCL_TOPO_PATTERN_NVLS) { // NVLS channels correspond to GPUs pulling from NVLS. So the more the better.
if(graph->nChannels > refGraph->nChannels && graph->nChannels <= system->nodes[GPU].count)
*copy = 1;
return scclSuccess;
}
// 2. Try to get better bandwidth
// Give a 15% perf bonus to paths not crossing nics
float target = 1.0 - (refGraph->crossNic - graph->crossNic) * .15;
if(graph->nChannels * graph->bwIntra > refGraph->nChannels * refGraph->bwIntra * target) {
*copy = 1;
return scclSuccess;
}
if(graph->nChannels * graph->bwIntra < refGraph->nChannels * refGraph->bwIntra * target)
return scclSuccess;
// 3. Less hops
if(graph->pattern == refGraph->pattern && graph->crossNic == refGraph->crossNic && graph->nHops < refGraph->nHops)
*copy = 1;
// 4. Prefer graph with more XGMI connections
if(graph->nChannels == refGraph->nChannels && scclTopoCountXGMI(system, refGraph) < scclTopoCountXGMI(system, graph))
*copy = 1;
return scclSuccess;
}
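// Example of rule 2: with refGraph at crossNic=1 and a candidate at crossNic=0,
// target = 1.0 - (1 - 0) * .15 = 0.85, so the non-crossNic candidate replaces
// the reference as soon as it reaches 85% of the reference's
// nChannels * bwIntra product.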
// Build a list of the best NETs to try.
//
// "gpu" can be set to -1 to build a list suitable for all GPUs (search start) or to a given gpu
// index when trying to get back to the NIC.
//
// The list is built the following way:
// 1. Select NETs starting with those close to GPU(s), based on paths[n].type.
// 2. For each GPU, once the list of NICs at a given distance is prepared, rotate the list
// based on the GPU NVML index so that e.g. GPU 1 chooses NIC 1 first instead of NIC 0, which
// might have been chosen by GPU 0 (case with multiple independent communicators per node).
// 3. Then add the NETs to the final list if they were not already added by another closer GPU.
scclResult_t scclTopoSelectNets(struct scclTopoSystem* system, int typeInter, int gpu, int* nets, int* netCountRet) {
int netCount = 0;
int localNetCount;
int* localNets;
SCCLCHECK(scclCalloc(&localNets, system->nodes[NET].count));
for(int t = 0; t <= typeInter; t++) {
for(int g = 0; g < system->nodes[GPU].count; g++) {
if(gpu != -1 && gpu != g)
continue;
localNetCount = 0;
struct scclTopoNode* gpu = system->nodes[GPU].nodes + g;
struct scclTopoLinkList* paths = gpu->paths[NET];
for(int n = 0; n < system->nodes[NET].count; n++) {
if(paths[n].type == t)
localNets[localNetCount++] = n;
}
if(localNetCount == 0)
continue;
// Shuffle by gpu NVML device number so that GPUs on the same PCI switch
// with multiple NICs don't use the same one as first choice.
for(int r = 0; r < system->nodes[GPU].nodes[g].gpu.dev % localNetCount; r++) {
int net0 = localNets[0];
for(int i = 0; i < localNetCount - 1; i++)
localNets[i] = localNets[i + 1];
localNets[localNetCount - 1] = net0;
}
// Append NICs to list
for(int i = 0; i < localNetCount; i++) {
int n = localNets[i];
int found = 0;
while(found < netCount && nets[found] != n)
found++;
if(found == netCount)
nets[netCount++] = n;
}
}
}
*netCountRet = netCount;
free(localNets);
return scclSuccess;
}
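// Rotation example: with two NICs at equal distance, a GPU with gpu.dev == 1
// rotates the local list [0, 1] once to [1, 0] (1 % 2 == 1), while a GPU with
// gpu.dev == 0 performs no rotation, so the two GPUs spread across both NICs.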
scclResult_t scclTopoSearchRecGpu(struct scclTopoSystem* system,
struct scclTopoGraph* graph,
struct scclTopoGraph* saveGraph,
struct scclTopoNode* gpu,
int step,
int backToNet,
int backToFirstRank,
int forcedOrder,
int* time) {
if((*time) <= 0)
return scclSuccess;
(*time)--;
int ngpus = system->nodes[GPU].count;
if(step == ngpus) {
// Determine whether we found a better solution or not
int copy = 0;
graph->nChannels++;
SCCLCHECK(scclTopoCompareGraphs(system, graph, saveGraph, &copy));
if(copy) {
memcpy(saveGraph, graph, sizeof(struct scclTopoGraph));
if(graph->nChannels == graph->maxChannels)
*time = -1;
}
if(graph->nChannels < graph->maxChannels) {
SCCLCHECK(scclTopoSearchRec(system, graph, saveGraph, time));
}
graph->nChannels--;
return scclSuccess;
}
graph->intra[graph->nChannels * ngpus + step] = gpu->gpu.rank;
int g = gpu - system->nodes[GPU].nodes;
if(step == backToNet) {
// first get back to NIC
if(system->nodes[NET].count) {
int startNetIndex;
SCCLCHECK(getNetIndex(system, graph->inter[graph->nChannels * 2], &startNetIndex));
struct scclTopoNode* startNet = system->nodes[NET].nodes + startNetIndex;
int netcount;
int* nets;
SCCLCHECK(scclCalloc(&nets, system->nodes[NET].count));
SCCLCHECK(scclTopoSelectNets(system, graph->typeInter, g, nets, &netcount));
for(int i = 0; i < netcount; i++) {
int n = nets[i];
struct scclTopoNode* net = system->nodes[NET].nodes + n;
if(graph->pattern == SCCL_TOPO_PATTERN_TREE && net->id != startNet->id)
continue; // Trees are symmetric
if(graph->crossNic != 1 && (net->net.asic != startNet->net.asic || net->net.port != startNet->net.port))
continue;
// Balanced Tree : count half of the bandwidth on first two GPUs
int nextBackToNet = -1;
float bwInterSave = graph->bwInter;
if(graph->pattern == SCCL_TOPO_PATTERN_BALANCED_TREE) {
// Count half of the bandwidth on each of the first two GPUs
if(step == 0)
nextBackToNet = 1;
else if(net->id != graph->inter[graph->nChannels * 2 + 1])
continue;
graph->bwInter /= 2;
}
SCCLCHECK(scclTopoFollowPath(system, graph, GPU, g, NET, n, 1, &net));
graph->bwInter = bwInterSave;
if(net) {
graph->inter[graph->nChannels * 2 + 1] = net->id;
SCCLCHECK(scclTopoSearchRecGpu(system, graph, saveGraph, gpu, step, nextBackToNet, backToFirstRank, forcedOrder, time));
if(graph->pattern == SCCL_TOPO_PATTERN_BALANCED_TREE)
graph->bwInter /= 2;
SCCLCHECK(scclTopoFollowPath(system, graph, GPU, g, NET, n, -1, &net));
graph->bwInter = bwInterSave;
}
}
free(nets);
}
} else if(graph->pattern == SCCL_TOPO_PATTERN_NVLS) {
SCCLCHECK(scclTopoSearchTryNvls(system, graph, saveGraph, g, ngpus, time));
} else if(step < system->nodes[GPU].count - 1) {
// Go to next GPU
int next[SCCL_TOPO_MAX_NODES];
int count;
if(forcedOrder == FORCED_ORDER_PCI) { // Try the PCI order
next[0] = step + 1;
count = 1;
} else if(forcedOrder == FORCED_ORDER_REPLAY) { // Try last channel order
SCCLCHECK(scclTopoReplayGetGpu(system, graph, step, next));
count = 1;
} else { // Normal search
SCCLCHECK(scclTopoSearchNextGpuSort(system, graph, gpu, next, &count, backToNet == -1 ? 0 : backToNet == step + 1 ? 1 : -1));
}
for(int i = 0; i < count; i++) {
SCCLCHECK(scclTopoSearchTryGpu(system, graph, saveGraph, step + 1, backToNet, backToFirstRank, forcedOrder, time, GPU, g, next[i]));
}
} else if(step == backToFirstRank) {
// Find first GPU and loop back to it
int p;
SCCLCHECK(getGpuIndex(system, graph->intra[graph->nChannels * ngpus], &p));
struct scclTopoNode* firstGpu;
SCCLCHECK(scclTopoFollowPath(system, graph, GPU, g, GPU, p, 1, &firstGpu));
if(firstGpu) {
SCCLCHECK(scclTopoSearchRecGpu(system, graph, saveGraph, firstGpu, step + 1, backToNet, -1, forcedOrder, time));
SCCLCHECK(scclTopoFollowPath(system, graph, GPU, g, GPU, p, -1, &firstGpu));
}
} else {
// Next path
SCCLCHECK(scclTopoSearchRecGpu(system, graph, saveGraph, gpu, ngpus, -1, -1, forcedOrder, time));
}
return scclSuccess;
}
scclResult_t scclTopoSearchRecNet(
struct scclTopoSystem* system, struct scclTopoGraph* graph, struct scclTopoGraph* saveGraph, int backToNet, int backToFirstRank, int* time) {
const int bw = graph->bwInter;
int* nets;
SCCLCHECK(scclCalloc(&nets, system->nodes[NET].count));
int netcount;
SCCLCHECK(scclTopoSelectNets(system, graph->typeInter, -1, nets, &netcount));
for(int i = 0; i < netcount; i++) {
int n = nets[i];
struct scclTopoNode* net = system->nodes[NET].nodes + n;
struct scclTopoNode* gpu;
if(graph->collNet && net->net.collSupport == 0)
continue;
if(net->net.bw < bw)
continue;
graph->inter[graph->nChannels * 2] = net->id;
graph->latencyInter = net->net.latency;
for(int i = 0; i < system->nodes[NET].count; i++) {
if((system->nodes[NET].nodes[i].net.asic == net->net.asic) && (system->nodes[NET].nodes[i].net.port == net->net.port)) {
system->nodes[NET].nodes[i].net.bw -= bw;
}
}
// NVLS needs to balance on all NICs
if(graph->pattern == SCCL_TOPO_PATTERN_NVLS) {
SCCLCHECK(scclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, 0, time, -1, -1, nets[graph->nChannels]));
} else {
if(graph->nChannels > 0) {
// Try to replay the last channel
int g;
SCCLCHECK(scclTopoReplayGetGpu(system, graph, -1, &g));
SCCLCHECK(scclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, FORCED_ORDER_REPLAY, time, NET, n, g));
}
if(graph->nChannels == 0 || graph->sameChannels == 0) {
if(graph->nChannels == 0) {
// Always try the PCI order first to set a reference, but don't count it toward the timeout nor let it run for long
struct scclTopoLinkList* paths = net->paths[GPU];
int f = 0, f_gdr = 0;
// find the first GPU that is closest to NIC
for(int i = 0; i < system->nodes[GPU].count; i++) {
if(paths[i].count <= paths[f].count) {
// prefer GPU direct RDMA
int gdr;
SCCLCHECK(scclTopoCheckGdr(system, system->nodes[GPU].nodes[i].id, net->id, 0, &gdr));
if(paths[i].count < paths[f].count || (paths[i].count == paths[f].count && !f_gdr && gdr)) {
f = i;
f_gdr = gdr;
}
}
}
int t = 1 << 10;
SCCLCHECK(scclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, FORCED_ORDER_PCI, &t, NET, n, 0));
if(t == -1)
*time = -1;
}
// Then try the most local GPUs
float maxBw = 0;
int minHops = 0xfffffff;
struct scclTopoLinkList* paths = net->paths[GPU];
for(int g = 0; g < system->nodes[GPU].count; g++) {
if(paths[g].bw > maxBw) {
maxBw = paths[g].bw;
minHops = paths[g].count;
} else if(paths[g].bw == maxBw && paths[g].count < minHops) {
minHops = paths[g].count;
}
}
if(maxBw >= bw) {
// In the first loop, avoid using GPUs in both directions between channels (one channel
// sending from that GPU and one channel receiving to that GPU), since that usually leads
// to lower BW.
for(int tryGpuBidir = 0; tryGpuBidir < 2; tryGpuBidir++) {
for(int g = 0; g < system->nodes[GPU].count; g++) {
if(paths[g].bw == maxBw && paths[g].count == minHops) {
gpu = system->nodes[GPU].nodes + g;
int gpuUsed = gpuPciBw(gpu) > 0 ? 0 : 1;
if(tryGpuBidir == gpuUsed) {
SCCLCHECK(scclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, 0, time, NET, n, g));
}
}
}
}
}
}
}
for(int i = 0; i < system->nodes[NET].count; i++) {
if((system->nodes[NET].nodes[i].net.asic == net->net.asic) && (system->nodes[NET].nodes[i].net.port == net->net.port)) {
system->nodes[NET].nodes[i].net.bw += bw;
}
}
}
free(nets);
return scclSuccess;
}
/* Search Patterns
*
* Intra-node
* Ring : GPU a -> GPU b -> .. -> GPU x -> GPU a
* (=Split Tree Loop)
* Tree : GPU a -> GPU b -> .. -> GPU x
* (=Split Tree)
*
* Inter-node
* Ring : NET n -> GPU a -> GPU b -> .. -> GPU x -> NET n (or m if crossNic)
* Tree : NET n -> GPU a -> GPU b -> .. -> GPU x
* `--> NET n (or m if crossNic)
* Split Tree : NET n -> GPU a -> GPU b -> .. -> GPU x
* `--> NET n (or m if crossNic)
* Split Tree Loop : NET n -> GPU a -> GPU b -> .. -> GPU x -> GPU a
* `--> NET n (or m if crossNic)
*/
scclResult_t scclTopoSearchParams(struct scclTopoSystem* system, int pattern, int* backToNet, int* backToFirstRank) {
if(system->nodes[NET].count && system->nodes[GPU].count != system->nRanks) {
if(pattern == SCCL_TOPO_PATTERN_RING)
*backToNet = system->nodes[GPU].count - 1;
else if(pattern == SCCL_TOPO_PATTERN_SPLIT_TREE)
*backToNet = 1;
else
*backToNet = 0;
*backToFirstRank = -1;
} else {
*backToNet = -1;
if(pattern == SCCL_TOPO_PATTERN_RING)
*backToFirstRank = system->nodes[GPU].count - 1;
else
*backToFirstRank = -1;
}
return scclSuccess;
}
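// Example: on an 8-GPU node with NICs and nRanks > 8, a RING search uses
// backToNet = 7 (return to the NIC after the last GPU) and backToFirstRank = -1;
// intra-node only, the same ring instead uses backToNet = -1 and
// backToFirstRank = 7 (close the loop on the first GPU).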
scclResult_t scclTopoSearchRec(struct scclTopoSystem* system, struct scclTopoGraph* graph, struct scclTopoGraph* saveGraph, int* time) {
int backToNet, backToFirstRank;
SCCLCHECK(scclTopoSearchParams(system, graph->pattern, &backToNet, &backToFirstRank));
if(system->nodes[NET].count && system->nodes[GPU].count != system->nRanks) {
// Start from NET
scclTopoSearchRecNet(system, graph, saveGraph, backToNet, backToFirstRank, time);
} else {
// Intra-node only.
if(graph->pattern == SCCL_TOPO_PATTERN_NVLS) {
SCCLCHECK(scclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, 0, time, -1, -1, graph->nChannels));
return scclSuccess;
} else if(graph->nChannels == 0) {
// Try PCI order first
SCCLCHECK(scclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, FORCED_ORDER_PCI, time, -1, -1, 0));
} else {
// Also try to replay previous channel
int g;
SCCLCHECK(scclTopoReplayGetGpu(system, graph, -1, &g));
SCCLCHECK(scclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, 0, time, -1, -1, g));
}
if(graph->sameChannels == 0 || graph->nChannels == 0) {
// Finally, try all other possibilities unless we are forced to use the same channels
for(int g = 0; g < system->nodes[GPU].count; g++) {
SCCLCHECK(scclTopoSearchTryGpu(system, graph, saveGraph, 0, backToNet, backToFirstRank, 0, time, -1, -1, g));
}
}
}
return scclSuccess;
}
/************************************/
/* User defined graph from XML file */
/************************************/
struct kvDict kvDictLinkType[] = {{"LOC", PATH_LOC},
{"NVL", PATH_NVL},
{"NVB", PATH_NVB},
{"PIX", PATH_PIX},
{"PXB", PATH_PXB},
{"PXN", PATH_PXN},
{"PHB", PATH_PHB},
{"SYS", PATH_SYS},
{NULL, 0}};
scclResult_t scclTopoGetChannelFromXml(struct scclXmlNode* xmlChannel, int c, struct scclTopoSystem* system, struct scclTopoGraph* graph) {
int ngpus = system->nodes[GPU].count;
int* inter = graph->inter + 2 * c;
int* intra = graph->intra + ngpus * c;
int n = 0, g = 0;
for(int s = 0; s < xmlChannel->nSubs; s++) {
struct scclXmlNode* sub = xmlChannel->subs[s];
int dev;
SCCLCHECK(xmlGetAttrInt(sub, "dev", &dev));
if(strcmp(sub->name, "net") == 0) {
inter[n++] = dev;
} else if(strcmp(sub->name, "gpu") == 0) {
int rank = -1;
for(int g = 0; g < ngpus; g++) {
if(system->nodes[GPU].nodes[g].gpu.dev == dev)
rank = system->nodes[GPU].nodes[g].gpu.rank;
}
if(rank == -1) {
WARN("XML Import Channel : dev %d not found.", dev);
return scclSystemError;
}
intra[g++] = rank;
}
}
return scclSuccess;
}
scclResult_t scclTopoGetGraphFromXmlSub(struct scclXmlNode* xmlGraph, struct scclTopoSystem* system, struct scclTopoGraph* graph, int* nChannels) {
int id;
SCCLCHECK(xmlGetAttrInt(xmlGraph, "id", &id));
if(graph->id != id)
return scclSuccess;
int crossNic;
SCCLCHECK(xmlGetAttrInt(xmlGraph, "crossnic", &crossNic));
if(scclParamCrossNic() == 0 && crossNic == 1)
return scclSuccess;
graph->crossNic = crossNic;
SCCLCHECK(xmlGetAttrInt(xmlGraph, "pattern", &graph->pattern));
SCCLCHECK(xmlGetAttrInt(xmlGraph, "nchannels", &graph->nChannels));
SCCLCHECK(xmlGetAttrFloat(xmlGraph, "speedintra", &graph->bwIntra));
SCCLCHECK(xmlGetAttrFloat(xmlGraph, "speedinter", &graph->bwInter));
if(xmlGetAttrFloat(xmlGraph, "latencyinter", &graph->latencyInter) != scclSuccess)
graph->latencyInter = 0.0;
const char* str;
SCCLCHECK(xmlGetAttr(xmlGraph, "typeintra", &str));
SCCLCHECK(kvConvertToInt(str, &graph->typeIntra, kvDictLinkType));
SCCLCHECK(xmlGetAttr(xmlGraph, "typeinter", &str));
SCCLCHECK(kvConvertToInt(str, &graph->typeInter, kvDictLinkType));
SCCLCHECK(xmlGetAttrInt(xmlGraph, "samechannels", &graph->sameChannels));
for(int s = 0; s < xmlGraph->nSubs; s++) {
SCCLCHECK(scclTopoGetChannelFromXml(xmlGraph->subs[s], s, system, graph));
}
*nChannels = xmlGraph->nSubs;
return scclSuccess;
}
scclResult_t scclTopoGetGraphFromXml(struct scclXmlNode* xmlGraphs, struct scclTopoSystem* system, struct scclTopoGraph* graph, int* nChannels) {
for(int s = 0; s < xmlGraphs->nSubs; s++) {
SCCLCHECK(scclTopoGetGraphFromXmlSub(xmlGraphs->subs[s], system, graph, nChannels));
}
return scclSuccess;
}
/* And the reverse : graph->xml */
scclResult_t scclTopoGetXmlFromChannel(struct scclTopoGraph* graph, int c, struct scclTopoSystem* system, struct scclXml* xml, struct scclXmlNode* parent) {
struct scclXmlNode* xmlChannel;
int ngpus = system->nodes[GPU].count;
int* inter = graph->inter + 2 * c;
int* intra = graph->intra + ngpus * c;
SCCLCHECK(xmlAddNode(xml, parent, "channel", &xmlChannel));
struct scclXmlNode* node;
if(system->nodes[NET].count) {
SCCLCHECK(xmlAddNode(xml, xmlChannel, "net", &node));
SCCLCHECK(xmlSetAttrInt(node, "dev", inter[0]));
}
for(int g = 0; g < ngpus; g++) {
SCCLCHECK(xmlAddNode(xml, xmlChannel, "gpu", &node));
int dev = -1;
for(int i = 0; i < ngpus; i++) {
if(system->nodes[GPU].nodes[i].gpu.rank == intra[g])
dev = system->nodes[GPU].nodes[i].gpu.dev;
}
if(dev == -1) {
WARN("XML Export Channel : rank %d not found.", intra[g]);
return scclInternalError;
}
SCCLCHECK(xmlSetAttrInt(node, "dev", dev));
}
if(system->nodes[NET].count) {
SCCLCHECK(xmlAddNode(xml, xmlChannel, "net", &node));
SCCLCHECK(xmlSetAttrInt(node, "dev", inter[1]));
}
return scclSuccess;
}
scclResult_t scclTopoGetXmlFromGraph(struct scclTopoGraph* graph, struct scclTopoSystem* system, struct scclXml* xml, struct scclXmlNode* parent) {
struct scclXmlNode* xmlGraph;
SCCLCHECK(xmlAddNode(xml, parent, "graph", &xmlGraph));
SCCLCHECK(xmlSetAttrInt(xmlGraph, "id", graph->id));
SCCLCHECK(xmlSetAttrInt(xmlGraph, "pattern", graph->pattern));
SCCLCHECK(xmlSetAttrInt(xmlGraph, "crossnic", graph->crossNic));
SCCLCHECK(xmlSetAttrInt(xmlGraph, "nchannels", graph->nChannels));
SCCLCHECK(xmlSetAttrFloat(xmlGraph, "speedintra", graph->bwIntra));
SCCLCHECK(xmlSetAttrFloat(xmlGraph, "speedinter", graph->bwInter));
SCCLCHECK(xmlSetAttrFloat(xmlGraph, "latencyinter", graph->latencyInter));
const char* str;
SCCLCHECK(kvConvertToStr(graph->typeIntra, &str, kvDictLinkType));
SCCLCHECK(xmlSetAttr(xmlGraph, "typeintra", str));
SCCLCHECK(kvConvertToStr(graph->typeInter, &str, kvDictLinkType));
SCCLCHECK(xmlSetAttr(xmlGraph, "typeinter", str));
SCCLCHECK(xmlSetAttrInt(xmlGraph, "samechannels", graph->sameChannels));
for(int c = 0; c < graph->nChannels; c++) {
SCCLCHECK(scclTopoGetXmlFromChannel(graph, c, system, xml, xmlGraph));
}
return scclSuccess;
}
scclResult_t scclTopoGetXmlFromGraphs(int ngraphs, struct scclTopoGraph** graphs, struct scclTopoSystem* system, struct scclXml* xml) {
xml->maxIndex = 0;
struct scclXmlNode* xmlGraphs;
SCCLCHECK(xmlAddNode(xml, NULL, "graphs", &xmlGraphs));
SCCLCHECK(xmlSetAttrInt(xmlGraphs, "version", SCCL_GRAPH_XML_VERSION));
for(int g = 0; g < ngraphs; g++) {
SCCLCHECK(scclTopoGetXmlFromGraph(graphs[g], system, xml, xmlGraphs));
}
return scclSuccess;
}
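// Sketch of the XML produced (element and attribute names taken from the
// functions above; the values shown are illustrative only):
//   <graphs version="...">
//     <graph id="0" pattern="4" crossnic="0" nchannels="1" speedintra="24.0"
//            speedinter="24.0" latencyinter="0.0" typeintra="NVL" typeinter="PIX"
//            samechannels="1">
//       <channel> <net dev="0"/> <gpu dev="0"/> ... <gpu dev="7"/> <net dev="0"/> </channel>
//     </graph>
//   </graphs>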
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
float speedArrayIntra[] = {48.0, 24.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12};
float speedArrayInter[] = {48.0, 24.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12};
#define NSPEEDSINTRA (sizeof(speedArrayIntra) / sizeof(float))
#define NSPEEDSINTER (sizeof(speedArrayInter) / sizeof(float))
#else
float speedArrayIntra[] = {40.0, 30.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0};
float speedArrayInter[] = {48.0, 30.0, 28.0, 24.0, 20.0, 18.0, 15.0, 12.0, 10.0, 9.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.4, 1.2, 0.24, 0.12};
#define NSPEEDSINTRA (sizeof(speedArrayIntra) / sizeof(float))
#define NSPEEDSINTER (sizeof(speedArrayInter) / sizeof(float))
float sm90SpeedArrayIntra[] = {60.0, 40.0, 30.0, 24.0, 20.0, 15.0, 12.0, 6.0, 3.0};
float sm90SpeedArrayInter[] = {48.0, 45.0, 42.0, 40.0, 30.0, 24.0, 20.0, 17.5, 15.0, 12.0, 6.0, 3.0, 2.4, 1.2, 0.24, 0.12};
#define NSPEEDSINTRA_SM90 (sizeof(sm90SpeedArrayIntra) / sizeof(float))
#define NSPEEDSINTER_SM90 (sizeof(sm90SpeedArrayInter) / sizeof(float))
#endif
RCCL_PARAM(ModelMatchingDisable, "MODEL_MATCHING_DISABLE", 0);
RCCL_PARAM(NChannels, "NCHANNELS", 0);
scclResult_t scclTopoCompute(scclTopoSystem* system, struct scclTopoGraph* graph) {
int ngpus = system->nodes[GPU].count;
graph->crossNic = scclParamCrossNic();
int crossNic = (system->nodes[NET].count > 1) && graph->crossNic &&
(graph->pattern == SCCL_TOPO_PATTERN_RING || graph->pattern == SCCL_TOPO_PATTERN_BALANCED_TREE ||
graph->pattern == SCCL_TOPO_PATTERN_SPLIT_TREE)
? 1
: 0;
graph->bwIntra = graph->bwInter = 0;
graph->latencyInter = 0;
if(graph->crossNic == 2)
graph->crossNic = 0;
graph->typeIntra = ngpus == 1 ? PATH_LOC : PATH_NVL;
graph->typeInter = PATH_PIX;
graph->nChannels = 0;
graph->nIntraChannels = 0;
memset(graph->intraNets, 0, MAXCHANNELS * SCCL_TOPO_MAX_NODES * 2 * sizeof(int));
int trySameChannels = graph->pattern == SCCL_TOPO_PATTERN_NVLS ? 0 : 1;
graph->sameChannels = trySameChannels;
char* str = getenv("SCCL_GRAPH_FILE");
if(str) {
INFO(SCCL_ENV, "SCCL_GRAPH_FILE set by environment to %s", str);
struct scclXml* xml;
SCCLCHECK(scclCalloc(&xml, 1));
SCCLCHECK(scclTopoGetXmlGraphFromFile(str, xml));
int nChannels;
SCCLCHECK(scclTopoGetGraphFromXml(xml->nodes, system, graph, &nChannels));
INFO(SCCL_GRAPH, "Search %d : %d channels loaded from XML graph", graph->id, nChannels);
free(xml);
if(graph->nChannels > 0)
return scclSuccess;
}
str = getenv("SCCL_RINGS");
char* strTrees = getenv("RCCL_TREES");
if(str || strTrees) {
// user supplied topo
if(strTrees) {
SCCLCHECK(parseGraphLight(strTrees, system, graph, NULL));
system->treeDefined = true;
} else {
SCCLCHECK(parseGraph(str, system, graph, NULL, NULL));
int arch, vendor, model;
SCCLCHECK(scclTopoCpuType(system, &arch, &vendor, &model));
if(graph->nChannels && arch == SCCL_TOPO_CPU_ARCH_X86 && vendor == SCCL_TOPO_CPU_VENDOR_AMD && model == SCCL_TOPO_CPU_TYPE_ROME) {
system->type |= RCCL_TOPO_4P2H_ROME;
}
}
} else if(!rcclParamModelMatchingDisable() && !graph->collNet) {
// try to match 8P6L
SCCLCHECK(parseChordalRing(system, graph));
if(graph->nChannels)
return scclSuccess;
// try to match Rome 4P2H
SCCLCHECK(parseRome4P2H(system, graph));
if(graph->nChannels)
return scclSuccess;
// try to match 1H16P
SCCLCHECK(parse1H16P(system, graph));
if(graph->nChannels)
return scclSuccess;
// try to match 4H4P
SCCLCHECK(parse4H4P(system, graph));
}
if(graph->nChannels)
return scclSuccess;
if((graph->pattern == SCCL_TOPO_PATTERN_RING) && (system->type & RCCL_TOPO_4P2H_ROME) && (ngpus == system->nRanks)) {
// limit single node max channels when searching ring graph on Rome
graph->maxChannels = 2;
}
if(ngpus == 1)
if(graph->pattern != SCCL_TOPO_PATTERN_RING)
graph->pattern = SCCL_TOPO_PATTERN_TREE;
int ccMin;
SCCLCHECK(scclTopoGetCompCap(system, &ccMin, NULL));
if(graph->pattern == SCCL_TOPO_PATTERN_NVLS && (system->nodes[NVS].count == 0 || ccMin < 90))
return scclSuccess;
if(system->nodes[NET].count == 0 && graph->pattern == SCCL_TOPO_PATTERN_NVLS) {
// Force intra-node NVLS algorithm to pull evenly from all GPUs.
graph->minChannels = graph->maxChannels = system->nodes[GPU].count;
}
struct scclTopoGraph tmpGraph;
memcpy(&tmpGraph, graph, sizeof(struct scclTopoGraph));
// First try crossnic, then decrease bw and finally increase bwIntra.
int nspeeds = 0;
float* speedArray = NULL;
if(system->nodes[NET].count == 0) {
nspeeds = NSPEEDSINTRA;
speedArray = speedArrayIntra;
} else {
nspeeds = NSPEEDSINTER;
speedArray = speedArrayInter;
}
int pass = 1;
int speedIndex = 0;
float maxBw = system->maxBw;
float totalBw = system->totalBw;
if(ngpus == 1 || graph->pattern != SCCL_TOPO_PATTERN_RING)
totalBw *= ngpus * 1.0 / (ngpus - 1);
while((speedArray[speedIndex] > maxBw || speedArray[speedIndex] * graph->minChannels > totalBw) && speedIndex < nspeeds - 1)
speedIndex++;
tmpGraph.bwIntra = tmpGraph.bwInter = speedArray[speedIndex];
int64_t globalTimeout = SCCL_SEARCH_GLOBAL_TIMEOUT;
search:
int time = tmpGraph.sameChannels ? SCCL_SEARCH_TIMEOUT_SAMECHANNELS
: tmpGraph.pattern == SCCL_TOPO_PATTERN_TREE ? SCCL_SEARCH_TIMEOUT_TREE
: SCCL_SEARCH_TIMEOUT;
tmpGraph.nChannels = 0;
globalTimeout -= time;
SCCLCHECK(scclTopoSearchRec(system, &tmpGraph, graph, &time));
#if 0
printf("Pattern %d, crossNic %d, Bw %g/%g, type %d/%d, channels %d-%d sameChannels %d -> nChannels %dx%g/%g %s\n", tmpGraph.pattern, tmpGraph.crossNic, tmpGraph.bwInter, tmpGraph.bwIntra, tmpGraph.typeInter, tmpGraph.typeIntra, tmpGraph.minChannels, tmpGraph.maxChannels, tmpGraph.sameChannels, graph->nChannels, graph->bwInter, graph->bwIntra, time == 0 ? "TIMEOUT" : time == -1 ? "PERFECT" : "");
for (int c=0; c<graph->nChannels; c++) {
printf("%2d : ", c);
for (int g=0; g<ngpus; g++) {
printf("%d ", graph->intra[c*ngpus+g]);
}
printf("[%d %d]", graph->inter[c*2+0], graph->inter[c*2+1]);
printf("\n");
}
#endif
// Optimal solution, stop here
if(time == -1)
goto done;
if(graph->nChannels * graph->bwInter >= system->totalBw)
goto done;
if(pass == 1) {
// First pass, we don't have a solution yet ; try other options
// Try having different channels
if(tmpGraph.sameChannels == 1) {
tmpGraph.sameChannels = 0;
goto search;
}
tmpGraph.sameChannels = trySameChannels;
if(time != -1)
globalTimeout += time;
else
globalTimeout = SCCL_SEARCH_GLOBAL_TIMEOUT;
if(globalTimeout < 0 && graph->nChannels)
goto done;
tmpGraph.pattern = graph->pattern;
int maxTypeIntra = system->nodes[NET].count > 0 ? tmpGraph.typeInter : PATH_SYS;
if(tmpGraph.typeIntra < maxTypeIntra && (graph->nChannels == 0 || tmpGraph.typeIntra < graph->typeIntra)) {
tmpGraph.typeIntra += 1;
goto search;
}
tmpGraph.typeIntra = ngpus == 1 ? PATH_LOC : PATH_NVL;
if(system->nodes[NET].count > 0 && tmpGraph.typeInter < PATH_SYS &&
(graph->nChannels == 0 || tmpGraph.typeInter < graph->typeInter || tmpGraph.typeInter < PATH_PXN)) {
tmpGraph.typeInter += 1;
goto search;
}
tmpGraph.typeInter = PATH_PIX;
if(crossNic && tmpGraph.crossNic == 0) {
// Try again with crossNic if permitted
tmpGraph.crossNic = crossNic;
goto search;
}
tmpGraph.crossNic = 0;
// Decrease bw until we find a solution
if((speedIndex < nspeeds - 1) && (graph->nChannels == 0 || (speedArray[speedIndex + 1] / graph->bwInter > .49))) {
tmpGraph.bwInter = tmpGraph.bwIntra = speedArray[++speedIndex];
goto search;
}
speedIndex = 0;
while(speedArray[speedIndex] > maxBw && speedIndex < nspeeds - 1)
speedIndex++;
tmpGraph.bwIntra = tmpGraph.bwInter = speedArray[speedIndex];
}
done:
// We have a solution. Start from that solution and move to pass 2.
if(pass == 1) {
time = -1;
memcpy(&tmpGraph, graph, sizeof(tmpGraph));
speedIndex = 0;
while(speedArray[speedIndex] > graph->bwInter && speedIndex < nspeeds - 1)
speedIndex++;
tmpGraph.bwIntra = tmpGraph.bwInter = speedArray[speedIndex];
tmpGraph.minChannels = graph->nChannels;
pass = 2;
}
// 3. See if we can increase bwIntra for trees (2 nodes or collnet)
if(pass == 2) {
if(time != 0 && graph->pattern != SCCL_TOPO_PATTERN_RING && tmpGraph.bwIntra == graph->bwIntra && tmpGraph.bwIntra < tmpGraph.bwInter * 2 &&
speedIndex > 0) {
tmpGraph.bwIntra = speedArray[--speedIndex];
goto search;
}
time = -1;
memcpy(&tmpGraph, graph, sizeof(tmpGraph));
}
if(graph->nChannels == 0 && graph->collNet == 0 && graph->pattern != SCCL_TOPO_PATTERN_NVLS) {
WARN("Could not find a path for pattern %d, falling back to simple order", graph->pattern);
for(int i = 0; i < ngpus; i++)
graph->intra[i] = system->nodes[GPU].nodes[i].gpu.rank;
graph->inter[0] = graph->inter[1] = 0;
graph->bwIntra = graph->bwInter = 0.1;
graph->typeIntra = graph->typeInter = PATH_SYS;
graph->nChannels = 1;
}
if(graph->nChannels == 0)
return scclSuccess;
if(graph->pattern == SCCL_TOPO_PATTERN_NVLS)
return scclSuccess;
if(graph->bwIntra < 25.0)
return scclSuccess;
if(ccMin > 80 && graph->bwIntra < 50.0 && graph->nChannels > 4)
return scclSuccess;
int dupChannels = std::min(graph->nChannels * 2, graph->maxChannels);
memcpy(graph->intra + graph->nChannels * ngpus, graph->intra, (dupChannels - graph->nChannels) * ngpus * sizeof(int));
memcpy(graph->inter + graph->nChannels * 2, graph->inter, (dupChannels - graph->nChannels) * 2 * sizeof(int));
graph->bwIntra /= DIVUP(dupChannels, graph->nChannels);
graph->bwInter /= DIVUP(dupChannels, graph->nChannels);
graph->nChannels = dupChannels;
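// Example: 2 channels at bwIntra 24.0 duplicated to dupChannels = 4 advertise
// 24.0 / DIVUP(4, 2) = 12.0 per channel, keeping the aggregate bandwidth constant.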
int nc = rcclParamNChannels();
if(graph->nChannels > 0 && nc > 0 && nc <= MAXCHANNELS / 2 && nc > graph->nChannels) {
int nChannels = nc - graph->nChannels;
int nnets = system->nodes[NET].count;
if(nnets <= 2) {
for(int i = 0; i < nChannels; ++i) {
memcpy(graph->intra + graph->nChannels * ngpus, graph->intra, ngpus * sizeof(int));
memcpy(graph->inter + graph->nChannels * 2, graph->inter, 2 * sizeof(int));
memcpy(graph->intraNets + graph->nChannels * ngpus * 2, graph->intraNets, 2 * ngpus * sizeof(int));
graph->nChannels++;
}
} else {
typedef struct {
int id;
int used;
} Net;
Net nets[nnets];
auto sortFunc = [](const void* a, const void* b) -> int { return ((Net*)a)->used - ((Net*)b)->used; };
memset(nets, 0, nnets * sizeof(Net));
for(int i = 0; i < nnets; ++i) {
nets[i].id = system->nodes[NET].nodes[i].id;
}
for(int i = 0; i < graph->nChannels; ++i) {
for(int j = 0; j < nnets; ++j) {
if(nets[j].id == *(graph->inter + i * 2) || nets[j].id == *(graph->inter + i * 2 + 1)) {
nets[j].used++;
}
}
}
for(int i = 0; i < nChannels; ++i) {
memcpy(graph->intra + graph->nChannels * ngpus, graph->intra, ngpus * sizeof(int));
qsort(nets, nnets, sizeof(Net), sortFunc);
*(graph->inter + graph->nChannels * 2) = nets[0].id;
nets[0].used++;
qsort(nets, nnets, sizeof(Net), sortFunc);
if(graph->crossNic == 0 || graph->crossNic == 2) {
*(graph->inter + graph->nChannels * 2 + 1) = nets[0].id;
nets[0].used++;
qsort(nets, nnets, sizeof(Net), sortFunc);
} else {
nets[0].used++;
qsort(nets, nnets, sizeof(Net), sortFunc);
*(graph->inter + graph->nChannels * 2 + 1) = nets[0].id;
}
nets[0].used++;
memcpy(graph->intraNets + graph->nChannels * ngpus * 2, graph->intraNets, 2 * ngpus * sizeof(int));
graph->nChannels++;
}
}
graph->bwIntra /= DIVUP(nc, graph->nChannels);
graph->bwInter /= DIVUP(nc, graph->nChannels);
}
return scclSuccess;
}
scclResult_t scclTopoPrintGraph(struct scclTopoSystem* system, struct scclTopoGraph* graph) {
INFO(SCCL_GRAPH,
"Pattern %d, crossNic %d, nChannels %d, bw %f/%f, type %s/%s, sameChannels %d",
graph->pattern,
graph->crossNic,
graph->nChannels,
graph->bwIntra,
graph->bwInter,
topoPathTypeStr[graph->typeIntra],
topoPathTypeStr[graph->typeInter],
graph->sameChannels);
int ngpus = system->nodes[GPU].count;
char line[1024];
for(int c = 0; c < graph->nChannels; c++) {
sprintf(line, "%2d :", c);
int offset = strlen(line);
if(system->nodes[NET].count > 0 && system->nodes[GPU].count != system->nRanks && !graph->nIntraChannels) {
sprintf(line + offset, " %s/%d", topoNodeTypeStr[NET], graph->inter[2 * c]);
offset = strlen(line);
}
for(int i = 0; i < ngpus; i++) {
int n = graph->intraNets[(ngpus * c + i) * 2] - 'N';
if(n >= 0 && n < system->nodes[NET].count) {
sprintf(line + offset, " NET/%d", n);
offset = strlen(line);
}
sprintf(line + offset, " %s/%d", topoNodeTypeStr[GPU], graph->intra[ngpus * c + i]);
offset = strlen(line);
n = graph->intraNets[(ngpus * c + i) * 2 + 1] - 'N';
if(n >= 0 && n < system->nodes[NET].count) {
sprintf(line + offset, " NET/%d", n);
offset = strlen(line);
}
}
if(system->nodes[NET].count > 0 && system->nodes[GPU].count != system->nRanks && !graph->nIntraChannels) {
sprintf(line + offset, " %s/%d", topoNodeTypeStr[NET], graph->inter[2 * c + 1]);
offset = strlen(line);
}
INFO(SCCL_GRAPH, "%s", line);
}
return scclSuccess;
}
scclResult_t scclTopoDumpGraphs(struct scclTopoSystem* system, int ngraphs, struct scclTopoGraph** graphs) {
char* str = getenv("SCCL_GRAPH_DUMP_FILE");
if(str) {
INFO(SCCL_ENV, "SCCL_GRAPH_DUMP_FILE set by environment to %s", str);
struct scclXml* xml;
SCCLCHECK(scclCalloc(&xml, 1));
SCCLCHECK(scclTopoGetXmlFromGraphs(ngraphs, graphs, system, xml));
SCCLCHECK(scclTopoDumpXmlToFile(str, xml));
free(xml);
}
return scclSuccess;
}
#include "comm.h"
// NVLS channels aren't compute channels. Find which NIC corresponds to our rank being the head
scclResult_t getNvlsNetDev(struct scclComm* comm, struct scclTopoGraph* graph, int* dev) {
int localRanks = comm->topo->nodes[GPU].count;
for(int c = 0; c < graph->nChannels; c++) {
if(graph->intra[c * localRanks] == comm->rank) {
*dev = graph->inter[c * 2];
return scclSuccess;
}
}
WARN("Could not find NIC for rank %d in NVLS graph\n", comm->rank);
return scclInternalError;
}
// 0: don't use PXN for P2P, 1: use PXN if needed, 2: use PXN as much as possible to maximize aggregation
SCCL_PARAM(P2pPxnLevel, "P2P_PXN_LEVEL", 2);
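// Assuming the usual SCCL_ prefix for SCCL_PARAM knobs, this would be driven by
// e.g. SCCL_P2P_PXN_LEVEL=1 to restrict PXN to NICs reachable at PATH_PXN or
// better (see the pxnLevel == 1 branch below).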
scclResult_t scclTopoGetNetDev(struct scclComm* comm, int rank, struct scclTopoGraph* graph, int channelId, int peerRank, int* dev, int* proxyRank) {
if(graph) {
// Honor the net device in the graph
int channel = channelId % graph->nChannels;
int ngpus = comm->topo->nodes[GPU].count;
int index = graph->intra[channel * ngpus] == rank ? 0 : 1;
if(graph->pattern != SCCL_TOPO_PATTERN_NVLS) {
*dev = graph->inter[channel * 2 + index];
} else {
SCCLCHECK(getNvlsNetDev(comm, graph, dev));
}
SCCLCHECK(scclTopoGetIntermediateRank(comm->topo, rank, *dev, proxyRank));
} else if(peerRank == -1) {
return scclInternalError;
} else {
// Start with our local NIC and local Rank
SCCLCHECK(scclTopoGetLocalNet(comm->topo, rank, channelId, dev));
*proxyRank = rank;
int pxnLevel = scclPxnDisable(comm) == 1 ? 0 : scclParamP2pPxnLevel();
// See whether we can use the remote rank preferred device.
if(scclParamCrossNic() == 0 || (pxnLevel != 0)) {
// Find local NIC number close to local cudaDev
int cudaDev = comm->peerInfo[peerRank].cudaDev;
int localRank;
if(scclTopoDevToRank(comm->topo, cudaDev, &localRank) != scclSuccess)
return scclSuccess;
int netDev;
SCCLCHECK(scclTopoGetLocalNet(comm->topo, localRank, channelId, &netDev));
int n;
// Check that device exists on our node
if(scclParamCrossNic() == 0) {
if(scclTopoIdToIndex(comm->topo, NET, netDev, &n) != scclSuccess) {
WARN("Rank %d requires NIC %d but that NIC is not available for rank %d", peerRank, netDev, rank);
return scclInvalidUsage;
}
*dev = netDev;
}
if(pxnLevel == 1) {
int g, n;
SCCLCHECK(scclTopoRankToIndex(comm->topo, rank, &g));
SCCLCHECK(scclTopoIdToIndex(comm->topo, NET, netDev, &n));
struct scclTopoNode* gpu = comm->topo->nodes[GPU].nodes + g;
if(gpu->paths[NET][n].type <= PATH_PXN) {
*dev = netDev;
SCCLCHECK(scclTopoGetIntermediateRank(comm->topo, rank, *dev, proxyRank));
}
} else if(pxnLevel == 2) {
// Check which local GPU corresponds to that NIC and see if we can use PXN.
int n, g1, g2;
SCCLCHECK(scclTopoIdToIndex(comm->topo, NET, netDev, &n));
SCCLCHECK(scclTopoRankToIndex(comm->topo, rank, &g1));
SCCLCHECK(scclTopoGetLocalGpu(comm->topo, netDev, &g2));
if(g2 != -1) {
struct scclTopoNode* peerGpu = comm->topo->nodes[GPU].nodes + g2;
if(peerGpu->paths[GPU][g1].type <= PATH_NVL && peerGpu->paths[NET][n].type <= PATH_PXB) {
*proxyRank = peerGpu->gpu.rank;
*dev = netDev;
return scclSuccess;
}
}
}
}
}
return scclSuccess;
}
scclResult_t scclTopoGetIntraNetDev(struct scclTopoSystem* system, int rank, struct scclTopoGraph* graph, int channelId, int type, int* dev) {
*dev = -1;
if(graph && graph->nIntraChannels) {
int n1 = -1;
int ngpus = system->nodes[GPU].count;
int nnets = system->nodes[NET].count;
int chan = channelId % graph->nIntraChannels;
for(int i = 0; i < ngpus; i++) {
if(graph->intra[ngpus * chan + i] == rank) {
n1 = graph->intraNets[(ngpus * chan + i) * 2 + type] - 'N';
break;
}
}
if(n1 >= 0 && n1 < nnets) {
*dev = n1;
}
}
return scclSuccess;
}
scclResult_t scclTopoGetLinkType(struct scclTopoSystem* system, int cudaDev1, int cudaDev2, bool* isXGMI, int maxInter, int nInter, int* inter) {
int interGpus[MAX_XGMI_INTER_GPUS + 1];
int ngpus = system->nodes[GPU].count;
*isXGMI = false;
// check for direct XGMI connection
for(int i = 0; i < ngpus; i++) {
if(system->nodes[GPU].nodes[i].gpu.dev == cudaDev1) {
struct scclTopoNode* node = system->nodes[GPU].nodes + i;
for(int k = 0; k < system->nodes[GPU].count; k++) {
if(node->paths[GPU][k].count == 1) {
struct scclTopoLink* link = node->paths[GPU][k].list[0];
struct scclTopoNode* remNode = link->remNode;
if(remNode->gpu.dev == cudaDev2) {
*isXGMI = (link->type == LINK_NVL);
if(*isXGMI)
return scclSuccess;
}
}
}
}
}
// try intermediate GPUs
if(maxInter) {
// check if there are intermediate GPUs that are connected to both
bool res1, res2, res3;
int j;
for(j = 0; j < nInter; j++) {
scclTopoGetLinkType(system, inter[j], inter[j + 1], &res1, 0);
if(!res1)
break;
}
if(j < nInter)
return scclSuccess;
if(nInter > 0 && inter != nullptr) {
scclTopoGetLinkType(system, inter[nInter], cudaDev2, &res2, 0);
if(res2) {
*isXGMI = true;
return scclSuccess;
}
memcpy(interGpus + 1, inter + 1, sizeof(int) * nInter);
}
interGpus[0] = cudaDev1;
// add one more intermediate GPU recursively until reaching the max depth
nInter++;
if(nInter + 2 > ngpus || nInter > MAX_XGMI_INTER_GPUS || nInter > maxInter)
return scclSuccess;
for(int i = 0; i < ngpus; i++) {
int dev = system->nodes[GPU].nodes[i].gpu.dev;
// skip the destination GPU and any GPU already on the path
if(dev == cudaDev2)
continue;
for(j = 0; j < nInter; j++)
if(dev == interGpus[j])
break;
if(j < nInter)
continue;
// check connectivity with intermediate GPUs
interGpus[nInter] = dev;
scclTopoGetLinkType(system, cudaDev1, cudaDev2, &res3, maxInter, nInter, interGpus);
if(res3) {
*isXGMI = true;
return scclSuccess;
}
}
}
return scclSuccess;
}
} // namespace detect
} // namespace topology
} // namespace hardware
} // namespace sccl
#include "sccl.h"
namespace sccl {
namespace hardware {
namespace topology {
namespace detect {
#define RANK_TO_INDEX(r) ((r) > root ? (r) - 1 : (r))
/* Btree which alternates leaves and nodes.
* Assumes root is 0, which conveniently builds a tree on powers of two,
* (because we have pow2-1 ranks) which lets us manipulate bits.
* Find first non-zero bit, then :
* Find the parent :
* xx01[0] -> xx10[0] (1,5,9 below) or xx00[0] if xx10[0] is out of bounds (13 below)
* xx11[0] -> xx10[0] (3,7,11 below)
* Find the children :
* xx10[0] -> xx01[0] (2,4,6,8,10,12) or -1 (1,3,5,7,9,11,13)
* xx10[0] -> xx11[0] (2,4,6,8,10) or xx101[0] (12) or xx1001[0] ... or -1 (1,3,5,7,9,11,13)
*
* Illustration :
* 0---------------8
* ______/ \______
* 4 12
* / \ / \
* 2 6 10 \
* / \ / \ / \ \
* 1 3 5 7 9 11 13
*/
scclResult_t scclGetBtree(int nranks, int rank, int* u, int* d0, int* d1, int* parentChildType) {
int up, down0, down1;
int bit;
for(bit = 1; bit < nranks; bit <<= 1) {
if(bit & rank)
break;
}
if(rank == 0) {
*u = -1;
*d0 = -1;
// Child rank is > 0 so it has to be our child 1, not 0.
*d1 = nranks > 1 ? bit >> 1 : -1;
return scclSuccess;
}
up = (rank ^ bit) | (bit << 1);
// If we are smaller than the parent, we are its first child; otherwise its second.
if(up >= nranks)
up = (rank ^ bit);
*parentChildType = (rank < up) ? 0 : 1;
*u = up;
int lowbit = bit >> 1;
// down0 is always within bounds
down0 = lowbit == 0 ? -1 : rank - lowbit;
down1 = lowbit == 0 ? -1 : rank + lowbit;
// Make sure down1 is within bounds
while(down1 >= nranks) {
down1 = lowbit == 0 ? -1 : rank + lowbit;
lowbit >>= 1;
}
*d0 = down0;
*d1 = down1;
return scclSuccess;
}
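// Illustrative sketch (hypothetical harness, not part of the library): calling
// scclGetBtree for a 14-rank communicator reproduces the tree drawn above,
// e.g. rank 4 gets parent 8 and children 2/6, and rank 12 gets children 10
// and 13 after the bounds-clamping loop. Assumes <cstdio>.
#if 0
#include <cstdio>
static void printBtree14() {
    for(int rank = 0; rank < 14; rank++) {
        int u = -1, d0 = -1, d1 = -1, t = -1;
        scclGetBtree(14, rank, &u, &d0, &d1, &t);
        printf("rank %2d: up %2d down0 %2d down1 %2d\n", rank, u, d0, d1);
    }
}
#endif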
/* Build a double binary tree. Take the previous tree for the first tree.
* For the second tree, we use a mirror tree (if nranks is even)
*
* 0---------------8 3----------------11
* ______/ \ / \______
* 4 \ / 7
* / \ \ / / \
* 2 6 10 1 5 9
* / \ / \ / \ / \ / \ / \
* 1 3 5 7 9 11 0 2 4 6 8 10
*
* or shift it by one rank (if nranks is odd).
*
* 0---------------8 1---------------9
* ______/ \______ ______/ \______
* 4 12 5 0
* / \ / / \ /
* 2 6 10 3 7 11
* / \ / \ / \ / \ / \ / \
* 1 3 5 7 9 11 2 4 6 8 10 12
*/
scclResult_t scclGetDtree(int nranks, int rank, int* s0, int* d0_0, int* d0_1, int* parentChildType0, int* s1, int* d1_0, int* d1_1, int* parentChildType1) {
// First tree ... use a btree
scclGetBtree(nranks, rank, s0, d0_0, d0_1, parentChildType0);
// Second tree ... mirror or shift
if(nranks % 2 == 1) {
// shift
int shiftrank = (rank - 1 + nranks) % nranks;
int u, d0, d1;
scclGetBtree(nranks, shiftrank, &u, &d0, &d1, parentChildType1);
*s1 = u == -1 ? -1 : (u + 1) % nranks;
*d1_0 = d0 == -1 ? -1 : (d0 + 1) % nranks;
*d1_1 = d1 == -1 ? -1 : (d1 + 1) % nranks;
} else {
// mirror
int u, d0, d1;
scclGetBtree(nranks, nranks - 1 - rank, &u, &d0, &d1, parentChildType1);
*s1 = u == -1 ? -1 : nranks - 1 - u;
*d1_0 = d0 == -1 ? -1 : nranks - 1 - d0;
*d1_1 = d1 == -1 ? -1 : nranks - 1 - d1;
}
return scclSuccess;
}
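// Illustrative sketch (hypothetical harness): printing both trees returned by
// scclGetDtree shows how the mirror (even nranks) or shift (odd nranks) gives
// each rank different positions in the two trees, which is what lets both
// trees run concurrently. Assumes <cstdio>.
#if 0
#include <cstdio>
static void printDtree(int nranks) {
    for(int rank = 0; rank < nranks; rank++) {
        int s0, d0_0, d0_1, t0, s1, d1_0, d1_1, t1;
        scclGetDtree(nranks, rank, &s0, &d0_0, &d0_1, &t0, &s1, &d1_0, &d1_1, &t1);
        printf("rank %2d: tree0 up %2d down %2d/%2d | tree1 up %2d down %2d/%2d\n",
               rank, s0, d0_0, d0_1, s1, d1_0, d1_1);
    }
}
#endif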
} // namespace detect
} // namespace topology
} // namespace hardware
} // namespace sccl
#include "core.h"
#include "devcomm.h"
#include "comm.h"
#include "topo.h"
namespace sccl {
namespace hardware {
namespace topology {
namespace detect {
SCCL_PARAM(Nthreads, "NTHREADS", -2);
SCCL_PARAM(Ll128Nthreads, "LL128_NTHREADS", -2);
static int getNthreads(const char* name, int env, int min, int max, int def, int WarpSize) {
int nt = env;
if(nt > 0) {
if(nt % WarpSize != 0) {
WARN("Invalid %s %d (must be a multiple of %d)", name, nt, WarpSize);
nt = max;
} else if(nt > max) {
WARN("Invalid %s %d (maximum %d).", name, nt, max);
nt = max;
} else if(nt < min) {
WARN("Invalid %s %d (minimum %d).", name, nt, min);
nt = min;
}
} else {
nt = def;
}
return nt;
}
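// Hypothetical sanity checks for the clamping above (assuming WarpSize == 64
// and <cassert>): a valid value is kept, a non-warp-aligned value warns and is
// clamped to max, and an unset env value (<= 0, e.g. the -2 default) falls
// through to def.
#if 0
#include <cassert>
static void getNthreadsExamples() {
    assert(getNthreads("SCCL_NTHREADS", 384, 256, 1024, 512, 64) == 384);  // valid value kept
    assert(getNthreads("SCCL_NTHREADS", 300, 256, 1024, 512, 64) == 1024); // not warp-aligned -> max
    assert(getNthreads("SCCL_NTHREADS", -2, 256, 1024, 512, 64) == 512);   // unset -> default
}
#endif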
scclResult_t parseList(const char* str, const char* elems[], int nelems, int* list) {
int def, set;
if(str[0] == '^') {
def = 1;
set = 0;
str++;
} else {
def = 0;
set = 1;
}
for(int i = 0; i < nelems; i++)
list[i] = def;
char* tokStr = strdup(str);
char* tmpStr;
char* token = strtok_r(tokStr, ",", &tmpStr);
while(token) {
for(int i = 0; i < nelems; i++)
if(strcasecmp(token, elems[i]) == 0)
list[i] = set;
token = strtok_r(NULL, ",", &tmpStr);
}
free(tokStr);
return scclSuccess;
}
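// Illustrative sketch of the list semantics above, assuming a protocol name
// table of { "LL", "LL128", "Simple" } (hypothetical here): a plain list is a
// whitelist, while a leading '^' enables everything and then disables the
// named items, so both calls below yield {1, 0, 1}.
#if 0
static void parseListExample() {
    const char* elems[] = {"LL", "LL128", "Simple"};
    int list[3];
    parseList("LL,Simple", elems, 3, list); // list == {1, 0, 1}
    parseList("^LL128", elems, 3, list);    // list == {1, 0, 1}
}
#endif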
// Latencies in us, Bandwidths in GB/s
// Tree { LL, LL128, Simple } , Ring { LL, LL128, Simple }
static const float baseLat[SCCL_NUM_ALGORITHMS][SCCL_NUM_PROTOCOLS] = {{12.0, 12.0, 17.0},
{12.0, 12.0, 17.0}, // Tree, Ring
{12.0, 12.0, 17.0},
{12.0, 12.0, 17.0}, // Collnet Direct, Chain
{0, 0, 0},
{0, 0, 0}}; // NVLS, NVLS Tree
// NVLink, PCI, Network
#define SCCL_HW_NVLINK 0
#define SCCL_HW_PCI 1
#define SCCL_HW_NET 2
struct tuningModel {
float hwLat[3][SCCL_NUM_ALGORITHMS][SCCL_NUM_PROTOCOLS];
float bwRatio[2][SCCL_NUM_ALGORITHMS][SCCL_NUM_PROTOCOLS];
float treeCorrectionFactor[SCCL_NUM_PROTOCOLS][27];
float ringCorrectionFactor[SCCL_NUM_PROTOCOLS][27];
};
static struct tuningModel tuning_model_0{
.hwLat =
{
/* NVLINK */
{/* Tree (LL/LL128/Simple)*/ {0.8, 1.4, 2.5},
/* Ring (LL/LL128/Simple)*/ {0.8, 2.2, 3.6},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 0.8},
/* CollNetChain (Simple)*/ {0.0, 0.0, 1.4},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* PCI */
{/* Tree (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* Ring (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 5.7},
/* CollNetChain (Simple)*/ {0.0, 0.0, 5.7},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* NET */
{/* Tree (LL/LL128/Simple)*/ {11.8, 18.2, 20.8},
/* Ring (LL/LL128/Simple)*/ {9.5, 19.8, 15.1},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 11.8},
/* CollNetChain (Simple)*/ {0.0, 0.0, 18.2},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.bwRatio =
{
/* 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.28, 0.22, 0.91},
/* Ring (LL/LL128/Simple)*/ {0.31, 0.34, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* more than 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.04, 0.22, 0.95},
/* Ring (LL/LL128/Simple)*/ {0.04, 0.34, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.treeCorrectionFactor =
{
{
0.1, 0.2, 0.1, 0.1, 0.9, 0.3, 0.4, 0.1, 0.2, 0.4, 0.2, 0.1, 0.3, 0.3, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
},
{
0.1, 0.3, 1.0, 0.1, 0.5, 1.0, 0.9, 1.0, 1.0, 1.0, 0.3, 0.1, 0.4, 0.5, 0.5, 0.4, 0.4, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
},
// { 0.2, 1.0, 0.1, 0.1, 0.7, 0.2, 0.4, 0.1, 0.1, 0.3, 0.4, 0.3, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9, },
{
0.2, 1.0, 0.1, 0.1, 0.7, 0.2, 0.4, 0.1, 0.1, 0.3, 0.4, 0.3, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 0.9, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4,
},
},
.ringCorrectionFactor =
{
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.4, 0.2, 0.3, 0.5, 0.3, 0.1, 0.5, 0.5, 0.3, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
},
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.7, 0.5, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3,
},
{
1.0, 0.8, 0.2, 1.0, 1.0, 0.3, 1.0, 0.1, 0.1, 0.2, 0.2, 0.1, 0.5, 1.0, 0.8, 0.8, 1.0, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
},
},
};
static struct tuningModel tuning_model_1{
.hwLat =
{
/* NVLINK */
{/* Tree (LL/LL128/Simple)*/ {1.5, 1.5, 4.5},
/* Ring (LL/LL128/Simple)*/ {1.5, 1.5, 4.5},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 4.5},
/* CollNetChain (Simple)*/ {0.0, 0.0, 4.5},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* PCI */
{/* Tree (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* Ring (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 5.7},
/* CollNetChain (Simple)*/ {0.0, 0.0, 5.7},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* NET */
{/* Tree (LL/LL128/Simple)*/ {33.0, 33.0, 15.8},
/* Ring (LL/LL128/Simple)*/ {5.1, 5.1, 68.8},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 15.8},
/* CollNetChain (Simple)*/ {0.0, 0.0, 15.8},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.bwRatio =
{
/* 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.30, 1.00, 0.99},
/* Ring (LL/LL128/Simple)*/ {0.31, 1.00, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* more than 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.15, 1.00, 0.42},
/* Ring (LL/LL128/Simple)*/ {0.20, 1.00, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.treeCorrectionFactor =
{
{
0.5, 0.4, 0.7, 0.6, 1.0, 1.0, 0.5, 0.4, 0.1, 0.5, 0.4, 0.6, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.6, 0.5, 0.4, 0.4, 0.3, 0.2, 0.1, 0.1, 0.1,
},
{
0.5, 0.4, 0.7, 0.6, 1.0, 1.0, 0.5, 0.4, 0.1, 0.5, 0.4, 0.6, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.6, 0.5, 0.4, 0.4, 0.3, 0.2, 0.1, 0.1, 0.1,
},
// { 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.4, 0.5, 0.1, 0.6, 1.0, 1.0, 1.0, 0.6, 0.5, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.7, 0.5, 0.3, 0.3, },
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.4, 0.5, 0.1, 0.6, 1.0, 1.0, 1.0, 0.6, 0.5, 0.7, 1.0, 1.0, 1.0, 0.4, 0.4, 0.4, 0.4, 0.3, 0.2, 0.1, 0.1,
},
},
.ringCorrectionFactor =
{
{
1.0, 0.5, 1.0, 1.0, 0.6, 0.7, 1.0, 1.0, 0.2, 1.0, 0.9, 0.7, 1.0, 1.0, 1.0, 0.9, 0.9, 0.8, 0.8, 0.7, 0.6, 0.5, 0.5, 0.3, 0.2, 0.1, 0.1,
},
{
1.0, 0.5, 1.0, 1.0, 0.6, 0.7, 1.0, 1.0, 0.2, 1.0, 0.9, 0.7, 1.0, 1.0, 1.0, 0.9, 0.9, 0.8, 0.8, 0.7, 0.6, 0.5, 0.5, 0.3, 0.2, 0.1, 0.1,
},
{
0.3, 1.0, 0.3, 0.1, 0.1, 0.1, 0.3, 0.7, 1.0, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.3, 0.5, 0.9, 1.0, 1.0, 1.0, 1.0,
},
},
};
static struct tuningModel tuning_model_2{
.hwLat =
{
/* NVLINK */
{/* Tree (LL/LL128/Simple)*/ {1.5, 1.5, 4.5},
/* Ring (LL/LL128/Simple)*/ {1.5, 1.5, 4.5},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 4.5},
/* CollNetChain (Simple)*/ {0.0, 0.0, 4.5},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* PCI */
{/* Tree (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* Ring (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 5.7},
/* CollNetChain (Simple)*/ {0.0, 0.0, 5.7},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* NET */
{/* Tree (LL/LL128/Simple)*/ {27.9, 27.9, 15.8},
/* Ring (LL/LL128/Simple)*/ {12.1, 12.1, 68.8},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 15.8},
/* CollNetChain (Simple)*/ {0.0, 0.0, 15.8},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.bwRatio =
{
/* 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.30, 1.00, 0.99},
/* Ring (LL/LL128/Simple)*/ {0.31, 1.00, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* more than 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.07, 1.00, 0.42},
/* Ring (LL/LL128/Simple)*/ {0.08, 1.00, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.treeCorrectionFactor =
{
{
0.1, 0.4, 0.3, 0.3, 0.2, 0.4, 0.5, 0.1, 0.1, 0.6, 0.7, 0.7, 0.8, 1.0, 0.9, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
},
{
0.1, 0.4, 0.3, 0.3, 0.2, 0.4, 0.5, 0.1, 0.1, 0.6, 0.7, 0.7, 0.8, 1.0, 0.9, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
},
// { 1.0, 0.1, 0.1, 0.1, 0.1, 0.2, 0.3, 0.5, 0.1, 0.6, 0.9, 0.8, 0.7, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.7, 0.9, 0.9, 1.0, 1.0, 1.0, },
{
1.0, 0.1, 0.1, 0.1, 0.1, 0.2, 0.3, 0.5, 0.1, 0.6, 0.9, 0.8, 0.7, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 0.4, 0.4, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4,
},
},
.ringCorrectionFactor =
{
{
0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.4, 1.0, 1.0, 1.0, 1.0, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
},
{
0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.4, 1.0, 1.0, 1.0, 1.0, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
},
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.4, 0.5, 0.6, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
},
},
};
static struct tuningModel tuning_model_3{
.hwLat =
{
/* NVLINK */
{/* Tree (LL/LL128/Simple)*/ {0.8, 0.0, 2.5},
/* Ring (LL/LL128/Simple)*/ {0.8, 0.0, 3.6},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 0.8},
/* CollNetChain (Simple)*/ {0.0, 0.0, 0.0},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* PCI */
{/* Tree (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* Ring (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 5.7},
/* CollNetChain (Simple)*/ {0.0, 0.0, 5.7},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* NET */
{/* Tree (LL/LL128/Simple)*/ {12.5, 0.0, 22.4},
/* Ring (LL/LL128/Simple)*/ {9.5, 0.0, 19.8},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 12.5},
/* CollNetChain (Simple)*/ {0.0, 0.0, 0.0},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.bwRatio =
{
/* 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.20, 0.00, 1.75},
/* Ring (LL/LL128/Simple)*/ {0.20, 0.00, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* more than 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.20, 0.00, 0.96},
/* Ring (LL/LL128/Simple)*/ {0.20, 0.00, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.treeCorrectionFactor =
{
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 0.2, 1.0, 0.9, 1.0, 0.6, 0.4, 0.6, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
},
{
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
},
// { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.2, 1.0, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.7, 0.8, 0.9, 0.7, 0.7, },
{
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.2, 1.0, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.4, 0.4, 0.3, 0.3, 0.3, 0.4, 0.3, 0.3,
},
},
.ringCorrectionFactor =
{
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.1, 0.2, 0.1, 0.4, 0.4, 0.2, 0.2, 0.3, 0.7, 0.5, 0.4, 0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
},
{
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
},
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.5, 1.0, 0.1, 0.3, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.4, 0.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
},
},
};
static struct tuningModel tuning_model_4{
.hwLat =
{
/* NVLINK */
{/* Tree (LL/LL128/Simple)*/ {0.8, 1.4, 2.5},
/* Ring (LL/LL128/Simple)*/ {0.8, 2.2, 3.6},
/* CollNetDirect (Simple)*/ {0.8, 1.4, 2.5},
/* CollNetChain (Simple)*/ {0.8, 1.4, 2.5},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* PCI */
{/* Tree (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* Ring (LL/LL128/Simple)*/ {2.2, 2.2, 5.7},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 5.7},
/* CollNetChain (Simple)*/ {0.0, 0.0, 5.7},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* NET */
{/* Tree (LL/LL128/Simple)*/ {32.2, 34.4, 47.6},
/* Ring (LL/LL128/Simple)*/ {35.4, 87.8, 209.2},
/* CollNetDirect (Simple)*/ {0.0, 0.0, 47.6},
/* CollNetChain (Simple)*/ {0.0, 0.0, 47.6},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.bwRatio =
{
/* 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.16, 1.09, 1.61},
/* Ring (LL/LL128/Simple)*/ {0.15, 0.41, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
/* more than 2 nodes */
{/* Tree (LL/LL128/Simple)*/ {0.16, 1.09, 1.08},
/* Ring (LL/LL128/Simple)*/ {0.15, 0.41, 1.00},
/* CollNetDirect (Simple)*/ {0.00, 0.00, 1.00},
/* CollNetChain (Simple)*/ {0.00, 0.00, 1.00},
/* NVLS */ {0, 0, 0},
/* NVLS Tree */ {0, 0, 0}},
},
.treeCorrectionFactor =
{
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.1, 0.1, 0.2, 0.4, 0.6, 0.5, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
},
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.1, 0.1, 0.2, 1.0, 0.5, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
},
// { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.4, 0.3, 0.3, 0.1, 0.1, 1.0, 1.0, 0.7, 0.5, 0.6, 0.5, 0.6, 0.6, 0.5, 0.6, 0.6, 0.6, 0.7, },
// { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.4, 0.3, 0.3, 0.1, 0.1, 1.0, 1.0, 0.7, 0.5, 0.6, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, },
},
.ringCorrectionFactor =
{
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0.3, 0.1, 0.1, 0.1, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
},
{
0.4, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4, 0.2, 0.2, 0.1, 0.3, 1.0, 1.0, 0.7, 0.8, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9, 0.8, 0.5, 0.4, 0.3, 0.3,
},
{
0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 0.8, 0.5, 0.1, 0.7, 0.2, 0.4, 0.4, 0.6, 0.7, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
},
},
};
static struct tuningModel rcclTuningModel[] = {
tuning_model_0,
tuning_model_1,
tuning_model_2,
tuning_model_3,
tuning_model_4,
};
/* Array indexes used below */
#define VOLTA_COMPCAP_IDX 0
#define AMPERE_COMPCAP_IDX 1
#define HOPPER_COMPCAP_IDX 2
// LL128 max BW per channel
static const double llMaxBws[3][3] = {
/* Volta-N1/Intel-N2/Intel-N4 */ {39.0, 39.0, 20.4},
/* Ampere-N1/AMD-N2/AMD-N4 */ {87.7, 22.5 /*avg of ring & tree*/, 19.0},
/* Hopper-N1/AMD-N2/AMD-N4 */ {87.7, 22.5 /*avg of ring & tree*/, 19.0}};
static const double perChMaxRingLL128Bws[3][3] = {
/* Volta (N1/N2/N4) */ {20.0, 20.0, 20.0},
/* Ampere (N1/N2/N4) */ {20.0, 20.0, 20.0},
/* Hopper (N1/N2/N4) */ {36.7, 36.7, 36.7},
};
static const double perChMaxTreeLL128Bws[3][3] = {
/* Volta (N1/N2/N4) */ {20.0, 20.0, 20.0},
/* Ampere (N1/N2/N4) */ {20.0, 20.0, 20.0},
/* Hopper (N1/N2/N4) */ {36.7, 36.7, 29.0},
};
static const double perChMaxTreeBws[3][3] = {
/* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0},
/* Ampere (N1/N2/N4) */ {24.0, 23.6, 17.8},
/* Hopper (N1/N2/N4) */ {38.7, 41.4, 36.0},
};
// Network post overhead in ns (1000 = 1 us)
SCCL_PARAM(NetOverhead, "NET_OVERHEAD", -2);
static float getNetOverhead(struct scclComm* comm) {
if(scclParamNetOverhead() != -2)
return scclParamNetOverhead() * .001;
int cpuArch, cpuVendor, cpuModel;
SCCLCHECK(scclTopoCpuType(comm->topo, &cpuArch, &cpuVendor, &cpuModel));
if(cpuArch == SCCL_TOPO_CPU_ARCH_X86 && cpuVendor == SCCL_TOPO_CPU_VENDOR_INTEL)
return 1.0;
if(cpuArch == SCCL_TOPO_CPU_ARCH_X86 && cpuVendor == SCCL_TOPO_CPU_VENDOR_AMD)
return 2.0;
else
return 1.0;
}
scclResult_t scclTopoTuneModel(struct scclComm* comm, int minCompCap, int maxCompCap, struct scclTopoGraph** graphs) {
int simpleDefaultThreads = (graphs[SCCL_ALGO_RING]->bwIntra * graphs[SCCL_ALGO_RING]->nChannels <= PCI_BW) ? 256 : SCCL_SIMPLE_MAX_NTHREADS;
comm->maxThreads[SCCL_ALGO_RING][SCCL_PROTO_SIMPLE] =
getNthreads("SCCL_NTHREADS", scclParamNthreads(), 4 * comm->WarpSize, SCCL_MAX_NTHREADS, simpleDefaultThreads, comm->WarpSize);
comm->maxThreads[SCCL_ALGO_TREE][SCCL_PROTO_SIMPLE] = comm->maxThreads[SCCL_ALGO_COLLNET_DIRECT][SCCL_PROTO_SIMPLE] =
getNthreads("SCCL_NTHREADS", scclParamNthreads(), 4 * comm->WarpSize, SCCL_MAX_NTHREADS, SCCL_MAX_NTHREADS, comm->WarpSize);
comm->maxThreads[SCCL_ALGO_RING][SCCL_PROTO_LL] = comm->maxThreads[SCCL_ALGO_TREE][SCCL_PROTO_LL] =
comm->maxThreads[SCCL_ALGO_COLLNET_DIRECT][SCCL_PROTO_LL] =
getNthreads("SCCL_NTHREADS", scclParamNthreads(), 4 * comm->WarpSize, SCCL_MAX_NTHREADS, SCCL_MAX_NTHREADS, comm->WarpSize);
comm->maxThreads[SCCL_ALGO_RING][SCCL_PROTO_LL128] = comm->maxThreads[SCCL_ALGO_TREE][SCCL_PROTO_LL128] =
getNthreads("SCCL_LL128_NTHREADS", scclParamLl128Nthreads(), 4 * comm->WarpSize, SCCL_LL128_MAX_NTHREADS, SCCL_LL128_MAX_NTHREADS, comm->WarpSize);
int nNodes = comm->nNodes;
int nRanks = comm->nRanks;
if(nRanks <= 1)
return scclSuccess;
int compCapIndex = minCompCap >= 90 ? HOPPER_COMPCAP_IDX : minCompCap >= 80 ? AMPERE_COMPCAP_IDX : VOLTA_COMPCAP_IDX;
int cpuArch, cpuVendor, cpuModel;
SCCLCHECK(scclTopoCpuType(comm->topo, &cpuArch, &cpuVendor, &cpuModel));
int index2 = nNodes <= 2 ? nNodes - 1 : 2;
// LL: for single node, we look at GPU type; for multi-node, we look at CPU type
int index1 = nNodes == 1 ? compCapIndex : cpuVendor == SCCL_TOPO_CPU_VENDOR_AMD ? 1 : 0;
double llMaxBw = llMaxBws[index1][index2];
double perChMaxTreeBw = perChMaxTreeBws[compCapIndex][index2];
double perChMaxRingLL128Bw = perChMaxRingLL128Bws[compCapIndex][index2];
double perChMaxTreeLL128Bw = perChMaxTreeLL128Bws[compCapIndex][index2];
// De-penalize Tree/Simple latency on Power systems to favor Tree over Ring
// if (cpuArch == SCCL_TOPO_CPU_ARCH_POWER) hwLat[SCCL_HW_PCI][SCCL_ALGO_TREE][SCCL_PROTO_SIMPLE] = hwLat[SCCL_HW_PCI][SCCL_ALGO_RING][SCCL_PROTO_SIMPLE];
float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount
int intraHw[SCCL_NUM_ALGORITHMS], hw[SCCL_NUM_ALGORITHMS];
for(int a = 0; a < SCCL_NUM_ALGORITHMS; a++)
intraHw[a] = graphs[a]->typeIntra == LINK_NVL ? SCCL_HW_NVLINK : SCCL_HW_PCI;
for(int a = 0; a < SCCL_NUM_ALGORITHMS; a++)
hw[a] = nNodes == 1 ? intraHw[a] : SCCL_HW_NET;
for(int coll = 0; coll < SCCL_NUM_FUNCTIONS; coll++) {
int nsteps = coll == scclFuncAllReduce ? 2 * (nRanks - 1) : coll == scclFuncReduceScatter || coll == scclFuncAllGather ? nRanks - 1 : nRanks;
int nInterSteps = coll == scclFuncAllReduce ? (nNodes > 1 ? 2 * nNodes : 0)
: coll == scclFuncReduceScatter || coll == scclFuncAllGather ? nNodes - 1
: nNodes;
for(int a = 0; a < SCCL_NUM_ALGORITHMS; a++) {
if(coll == scclFuncBroadcast && a != SCCL_ALGO_RING)
continue;
if(coll == scclFuncReduce && a != SCCL_ALGO_RING)
continue;
if(coll == scclFuncReduceScatter && a != SCCL_ALGO_RING)
continue;
if(coll == scclFuncAllGather && a != SCCL_ALGO_RING)
continue;
for(int p = 0; p < SCCL_NUM_PROTOCOLS; p++) {
if((a == SCCL_ALGO_NVLS || a == SCCL_ALGO_NVLS_TREE) && p != SCCL_PROTO_SIMPLE)
continue;
int collnet = (a == SCCL_ALGO_COLLNET_DIRECT || a == SCCL_ALGO_COLLNET_CHAIN) ? 1 : 0;
float bw = nNodes <= 2 || collnet ? graphs[a]->bwIntra : graphs[a]->bwInter;
float busBw = comm->topo->baseBw != 0.0 ? comm->topo->baseBw : graphs[a]->nChannels * bw;
// INFO(SCCL_INIT, "algo %s proto %s busBw %f baseBw %f bw %f nChannels %d bwIntra %f bwInter %f", scclAlgoStr[a], scclProtoStr[p], busBw,
// comm->topo->baseBw, bw, graphs[a]->nChannels, graphs[a]->bwIntra, graphs[a]->bwInter);
// Various model refinements
if(nNodes <= 2)
busBw *= rcclTuningModel[comm->topo->tuning].bwRatio[0][a][p];
else
busBw *= rcclTuningModel[comm->topo->tuning].bwRatio[1][a][p];
if(a == SCCL_ALGO_COLLNET_DIRECT && p == SCCL_PROTO_SIMPLE && minCompCap >= 90)
busBw *= .85;
// Convert bus BW to algorithm BW
float ratio;
if(a == SCCL_ALGO_RING)
ratio = (1.0 * nRanks) / nsteps;
else if(a == SCCL_ALGO_NVLS)
ratio = 5.0 / 6.0;
else if(a == SCCL_ALGO_NVLS_TREE)
ratio = .70 * nNodes / (2 * (nNodes - 1));
else
ratio = .5;
comm->bandwidths[coll][a][p] = busBw * ratio;
comm->latencies[coll][a][p] = baseLat[a][p];
float intraLat = rcclTuningModel[comm->topo->tuning].hwLat[intraHw[a]][a][p];
float interLat = graphs[a]->latencyInter ? graphs[a]->latencyInter : rcclTuningModel[comm->topo->tuning].hwLat[SCCL_HW_NET][a][p];
// if (nNodes > 1 && p == SCCL_PROTO_LL) intraLat *= 1.8;
if(p == SCCL_PROTO_SIMPLE)
interLat += graphs[a]->latencyInter;
if(a == SCCL_ALGO_RING) {
float lat = rcclTuningModel[comm->topo->tuning].hwLat[hw[a]][a][p];
if((coll == scclFuncReduce || coll == scclFuncBroadcast)) {
if(graphs[a]->sameChannels) {
comm->latencies[coll][a][p] += lat;
} else {
if(p == SCCL_PROTO_SIMPLE)
lat = rcclTuningModel[comm->topo->tuning]
.hwLat[hw[a]][SCCL_ALGO_TREE][p]; // Add some chunk latency, waiting for proper chunk modeling
comm->latencies[coll][a][p] += nsteps * lat;
}
} else {
// Inter-node rings still have to launch nsteps * net overhead.
float netOverhead = 0.0;
if(nNodes > 1) {
netOverhead = getNetOverhead(comm);
if(p == SCCL_PROTO_SIMPLE)
netOverhead *= 3;
}
intraLat = std::max(intraLat, netOverhead);
comm->latencies[coll][a][p] += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat;
}
} else if(a == SCCL_ALGO_TREE) {
comm->latencies[coll][a][p] += 2 * ((nRanks / nNodes - 1) * intraLat + log2i(nNodes) * interLat);
} else if(a == SCCL_ALGO_COLLNET_DIRECT) {
comm->latencies[coll][a][p] +=
2 * (std::min(1, (nRanks / nNodes - 1)) * intraLat + (nRanks / nNodes - 1) * 0.5) + interLat; // Add 0.5 arity serialization latency
} else if(a == SCCL_ALGO_COLLNET_CHAIN) {
comm->latencies[coll][a][p] += 2 * (nRanks / nNodes - 1) * intraLat + interLat;
} else if(a == SCCL_ALGO_NVLS) {
if(nNodes > 1)
comm->latencies[coll][a][p] += rcclTuningModel[comm->topo->tuning].hwLat[SCCL_HW_NET][a][p];
} else if(a == SCCL_ALGO_NVLS_TREE) {
comm->latencies[coll][a][p] += 2 * (nNodes - 1) * rcclTuningModel[comm->topo->tuning].hwLat[SCCL_HW_NET][a][p];
}
}
}
}
// Protocols/Algorithms enable/disable, and user overrides.
// All are enabled except LL128, which is enabled by default only in certain cases.
int protoEnable[SCCL_NUM_PROTOCOLS] = {1, 2, 1};
int algoEnable[SCCL_NUM_ALGORITHMS] = {1, 1, 1, 1, 1, 1};
const char* protoStr = getenv("SCCL_PROTO");
if(protoStr) {
INFO(SCCL_ENV, "SCCL_PROTO set by environment to %s", protoStr);
SCCLCHECK(parseList(protoStr, scclProtoStr, SCCL_NUM_PROTOCOLS, protoEnable));
}
const char* algoStr = getenv("SCCL_ALGO");
if(algoStr) {
INFO(SCCL_ENV, "SCCL_ALGO set by environment to %s", algoStr);
SCCLCHECK(parseList(algoStr, scclAlgoStr, SCCL_NUM_ALGORITHMS, algoEnable));
}
if(comm->nNodes == 1)
algoEnable[SCCL_ALGO_NVLS_TREE] = 0;
// Disable CollNet if it is not supported
if(comm->collNetSupport == 0) {
algoEnable[SCCL_ALGO_COLLNET_DIRECT] = 0;
algoEnable[SCCL_ALGO_COLLNET_CHAIN] = 0;
if(comm->nNodes > 1)
algoEnable[SCCL_ALGO_NVLS] = 0;
// If user has hard set SCCL_ALGO=COLLNET, ignore it
if(algoEnable[SCCL_ALGO_RING] == 0 && algoEnable[SCCL_ALGO_TREE] == 0 && algoEnable[SCCL_ALGO_NVLS] == 0 && algoEnable[SCCL_ALGO_NVLS_TREE] == 0) {
algoEnable[SCCL_ALGO_RING] = algoEnable[SCCL_ALGO_TREE] = 1;
if(comm->rank == 0)
WARN("CollNet is not supported or fails to initialize, ignoring SCCL_ALGO=COLLNET");
}
} else {
// Disable CollNet+Direct if not on an NVSwitch system
int nvsCount = 0;
SCCLCHECK(scclTopoGetNvsCount(comm->topo, &nvsCount));
if(nvsCount == 0)
algoEnable[SCCL_ALGO_COLLNET_DIRECT] = 0;
}
for(int c = 0; c < SCCL_NUM_FUNCTIONS; c++)
for(int a = 0; a < SCCL_NUM_ALGORITHMS; a++)
for(int p = 0; p < SCCL_NUM_PROTOCOLS; p++) {
// Resolve protocols left at the default setting (2); currently this only applies to LL128
int pEnable = protoEnable[p];
if(pEnable == 2 && p == SCCL_PROTO_LL128) {
#if defined(ENABLE_LL128)
// Enable LL128 by default only on gfx90a with available tuning table
pEnable = (graphs[a]->typeInter <= PATH_PXB) && graphs[a]->typeIntra <= PATH_NVL &&
(IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx90a") && comm->topo->ll128Enabled)
? 1
: 0;
#else
pEnable = 0;
#endif
}
if(pEnable == 0)
comm->bandwidths[c][a][p] = 0;
// Never disable ring for non-allreduce operations. That allows running real apps with SCCL_ALGO=TREE.
if(a == SCCL_ALGO_RING && c != scclFuncAllReduce)
continue;
if(algoEnable[a] == 0)
comm->bandwidths[c][a][p] = 0;
}
if(comm->rank == 0) {
char line[1024];
for(int block = 0; block < 2; block++) {
sprintf(line, " Algorithm |");
for(int ba = 0; ba < SCCL_NUM_ALGORITHMS / 2; ba++) {
int a = block * SCCL_NUM_ALGORITHMS / 2 + ba;
sprintf(line + strlen(line), " %14s %14s %14s |", "", scclAlgoStr[a], "");
}
INFO(SCCL_TUNING, "%s", line);
sprintf(line, " Protocol |");
for(int ba = 0; ba < SCCL_NUM_ALGORITHMS / 2; ba++) {
for(int p = 0; p < SCCL_NUM_PROTOCOLS; p++) {
sprintf(line + strlen(line), " %14s |", scclProtoStr[p]);
}
}
INFO(SCCL_TUNING, "%s", line);
sprintf(line, " Max NThreads |");
for(int ba = 0; ba < SCCL_NUM_ALGORITHMS / 2; ba++) {
int a = block * SCCL_NUM_ALGORITHMS / 2 + ba;
for(int p = 0; p < SCCL_NUM_PROTOCOLS; p++) {
sprintf(line + strlen(line), " %14d |", comm->maxThreads[a][p]);
}
}
INFO(SCCL_TUNING, "%s", line);
for(int c = 0; c < SCCL_NUM_FUNCTIONS; c++) {
sprintf(line, "%13s |", scclFuncStr[c]);
for(int ba = 0; ba < SCCL_NUM_ALGORITHMS / 2; ba++) {
int a = block * SCCL_NUM_ALGORITHMS / 2 + ba;
for(int p = 0; p < SCCL_NUM_PROTOCOLS; p++) {
sprintf(line + strlen(line), "%8.1f/%6.1f |", comm->latencies[c][a][p], comm->bandwidths[c][a][p]);
}
}
INFO(SCCL_TUNING, "%s", line);
}
}
}
// Set per-thread amount of work before we increase nThreads and nChannels
for(int a = 0; a < SCCL_NUM_ALGORITHMS; a++) {
comm->threadThresholds[a][SCCL_PROTO_LL] = SCCL_LL_THREAD_THRESHOLD;
comm->threadThresholds[a][SCCL_PROTO_LL128] = SCCL_LL128_THREAD_THRESHOLD;
comm->threadThresholds[a][SCCL_PROTO_SIMPLE] = SCCL_SIMPLE_THREAD_THRESHOLD;
}
comm->threadThresholds[SCCL_ALGO_RING][SCCL_PROTO_LL] *= nRanks;
comm->threadThresholds[SCCL_ALGO_COLLNET_DIRECT][SCCL_PROTO_SIMPLE] = 256;
comm->threadThresholds[SCCL_ALGO_COLLNET_CHAIN][SCCL_PROTO_SIMPLE] = 256;
// Override defaults with user env
char* str = getenv("SCCL_THREAD_THRESHOLDS");
if(str) {
INFO(SCCL_ENV, "SCCL_THREAD_THRESHOLDS set by environment to %s", str);
ssize_t t[2][SCCL_NUM_PROTOCOLS] = {{-2, -2, -2}, {-2, -2, -2}};
sscanf(str, "%ld %ld %ld %ld %ld %ld", t[0], t[0] + 1, t[0] + 2, t[1], t[1] + 1, t[1] + 2);
for(int a = 0; a < 2; a++) {
for(int p = 0; p < SCCL_NUM_PROTOCOLS; p++) {
if(t[a][p] >= 0)
comm->threadThresholds[a][p] = t[a][p];
}
}
}
INFO(SCCL_INIT,
"threadThresholds %ld/%ld/%ld | %ld/%ld/%ld | %ld | %ld",
comm->threadThresholds[SCCL_ALGO_TREE][SCCL_PROTO_LL],
comm->threadThresholds[SCCL_ALGO_TREE][SCCL_PROTO_LL128],
comm->threadThresholds[SCCL_ALGO_TREE][SCCL_PROTO_SIMPLE],
comm->threadThresholds[SCCL_ALGO_RING][SCCL_PROTO_LL],
comm->threadThresholds[SCCL_ALGO_RING][SCCL_PROTO_LL128],
comm->threadThresholds[SCCL_ALGO_RING][SCCL_PROTO_SIMPLE],
comm->threadThresholds[SCCL_ALGO_COLLNET_DIRECT][SCCL_PROTO_SIMPLE],
comm->threadThresholds[SCCL_ALGO_COLLNET_CHAIN][SCCL_PROTO_SIMPLE]);
return scclSuccess;
}
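// Example of the override format parsed above (hedged): SCCL_THREAD_THRESHOLDS
// takes six numbers, tree LL/LL128/Simple followed by ring LL/LL128/Simple
// (matching the INFO line below); negative entries are ignored and keep the
// defaults. For instance
//   SCCL_THREAD_THRESHOLDS="-2 -2 32768 -2 -2 65536"
// would override only the Simple-protocol thresholds for tree and ring.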
scclResult_t scclTopoGetAlgoTime(struct scclInfo* info, int algorithm, int protocol, int numPipeOps, float* time) {
float bw = info->comm->bandwidths[info->coll][algorithm][protocol];
float lat = info->comm->latencies[info->coll][algorithm][protocol];
if(bw == 0) {
*time = -1.0;
return scclSuccess;
}
int logSize = log2i(info->nBytes >> 6);
if(algorithm == SCCL_ALGO_TREE) {
if(logSize < 27)
bw *= rcclTuningModel[info->comm->topo->tuning].treeCorrectionFactor[protocol][logSize];
else
bw *= rcclTuningModel[info->comm->topo->tuning].treeCorrectionFactor[protocol][26];
} else if(algorithm == SCCL_ALGO_RING && info->comm->nNodes > 1) {
if(logSize < 27)
bw *= rcclTuningModel[info->comm->topo->tuning].ringCorrectionFactor[protocol][logSize];
else
bw *= rcclTuningModel[info->comm->topo->tuning].ringCorrectionFactor[protocol][26];
}
// Tree pipelining saves latency in aggregation cases
int latCount = algorithm == SCCL_ALGO_RING ? numPipeOps : DIVUP(numPipeOps, SCCL_MAX_WORK_ELEMENTS);
*time = lat * latCount + (info->nBytes) / (1000 * bw);
return scclSuccess;
}
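// Worked example of the model above (illustrative numbers only): for a 64 MiB
// ring allreduce with bw = 100 GB/s and lat = 17 us, logSize =
// log2i((1 << 26) >> 6) = 20, so bw is scaled by ringCorrectionFactor[p][20];
// with a factor of 1.0 and numPipeOps = 1 the estimate is
//   time = 17 + 67108864 / (1000 * 100) ~= 688 us
// (nBytes / (1000 * bw) converts bytes at GB/s into microseconds).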
} // namespace detect
} // namespace topology
} // namespace hardware
} // namespace sccl
#pragma once
#include <stdint.h>
#include "base.h"
namespace sccl {
namespace hardware {} // namespace hardware
} // namespace sccl
#include <sys/types.h>
#include <unistd.h>
#include "ibvsymbols.h"
/* RDMA-core linking mode. Symbols are pointers to linked IB Verbs */
namespace sccl {
namespace hardware {
namespace net {
namespace device {
#define ASSIGN_SYM(container, symbol, name) container->name = &symbol;
// Passthrough function for ibv_reg_mr macro in verbs.h
/**
 * Register a memory region (MR) with a protection domain (PD).
 *
 * @param pd pointer to the protection domain
 * @param addr start address of the memory
 * @param length length of the memory region
 * @param access access permission flags
 * @return pointer to the registered MR on success, NULL on failure
 */
struct ibv_mr* ibv_internal_reg_mr(struct ibv_pd* pd, void* addr, size_t length, int access) { return ibv_reg_mr(pd, addr, length, access); }
// Passthrough function for the ibv_query_port macro in verbs.h
/**
 * Query the attributes of an IB device port.
 *
 * @param context IB device context
 * @param port_num port number
 * @param port_attr pointer to the structure receiving the port attributes
 * @return 0 on success, non-zero on failure
 *
 * @note This is a thin wrapper around ibv_query_port.
 */
int ibv_internal_query_port(struct ibv_context* context, uint8_t port_num, struct ibv_port_attr* port_attr) {
return ibv_query_port(context, port_num, port_attr);
}
/**
 * @brief Build the IB Verbs symbol table.
 *
 * Initializes the IB Verbs function pointers by assigning the address of each
 * internal implementation to the symbol-table struct. Each IB Verbs API is
 * bound to its internal counterpart via the ASSIGN_SYM macro.
 *
 * @param ibvSymbols pointer to the scclIbvSymbols struct that stores the function pointers
 * @return scclResult_t operation result; scclSuccess on success
 *
 1. ibv_get_device_list: get the list of InfiniBand devices available on the system.
 2. ibv_free_device_list: free a device list returned by ibv_get_device_list.
 3. ibv_get_device_name: get the name associated with a given InfiniBand device.
 4. ibv_open_device: create a context for a given InfiniBand device.
 5. ibv_close_device: close an InfiniBand device context.
 6. ibv_get_async_event: get the next asynchronous event for a device context.
 7. ibv_ack_async_event: acknowledge an asynchronous event returned by ibv_get_async_event.
 8. ibv_query_device: query the attributes of a given device.
 9. ibv_query_gid: query a global identifier (GID) of a device port.
 10. ibv_query_qp: query the attributes of a given queue pair (QP).
 11. ibv_alloc_pd: allocate a protection domain (PD) for a given context.
 12. ibv_dealloc_pd: free a protection domain allocated by ibv_alloc_pd.
 13. ibv_reg_mr_iova2: register a memory region (MR) with an I/O virtual address.
 14. ibv_reg_dmabuf_mr: register a memory region backed by a DMA-BUF.
 15. ibv_dereg_mr: deregister a memory region.
 16. ibv_create_cq: create a completion queue (CQ) for a given context.
 17. ibv_destroy_cq: destroy a given completion queue.
 18. ibv_create_qp: create a queue pair (QP) for a given context.
 19. ibv_modify_qp: modify a given queue pair.
 20. ibv_destroy_qp: destroy a given queue pair.
 21. ibv_fork_init: initialize the structures needed for fork support.
 22. ibv_event_type_str: return a string describing a given event type.
 */
scclResult_t buildIbvSymbols(struct scclIbvSymbols* ibvSymbols) {
ASSIGN_SYM(ibvSymbols, ibv_get_device_list, ibv_internal_get_device_list); //
ASSIGN_SYM(ibvSymbols, ibv_free_device_list, ibv_internal_free_device_list);
ASSIGN_SYM(ibvSymbols, ibv_get_device_name, ibv_internal_get_device_name);
ASSIGN_SYM(ibvSymbols, ibv_open_device, ibv_internal_open_device);
ASSIGN_SYM(ibvSymbols, ibv_close_device, ibv_internal_close_device);
ASSIGN_SYM(ibvSymbols, ibv_get_async_event, ibv_internal_get_async_event);
ASSIGN_SYM(ibvSymbols, ibv_ack_async_event, ibv_internal_ack_async_event);
ASSIGN_SYM(ibvSymbols, ibv_query_device, ibv_internal_query_device);
ASSIGN_SYM(ibvSymbols, ibv_query_gid, ibv_internal_query_gid);
ASSIGN_SYM(ibvSymbols, ibv_query_qp, ibv_internal_query_qp);
ASSIGN_SYM(ibvSymbols, ibv_alloc_pd, ibv_internal_alloc_pd);
ASSIGN_SYM(ibvSymbols, ibv_dealloc_pd, ibv_internal_dealloc_pd);
ASSIGN_SYM(ibvSymbols, ibv_reg_mr_iova2, ibv_internal_reg_mr_iova2);
ASSIGN_SYM(ibvSymbols, ibv_reg_dmabuf_mr, ibv_internal_reg_dmabuf_mr);
ASSIGN_SYM(ibvSymbols, ibv_dereg_mr, ibv_internal_dereg_mr);
ASSIGN_SYM(ibvSymbols, ibv_create_cq, ibv_internal_create_cq);
ASSIGN_SYM(ibvSymbols, ibv_destroy_cq, ibv_internal_destroy_cq);
ASSIGN_SYM(ibvSymbols, ibv_create_qp, ibv_internal_create_qp);
ASSIGN_SYM(ibvSymbols, ibv_modify_qp, ibv_internal_modify_qp);
ASSIGN_SYM(ibvSymbols, ibv_destroy_qp, ibv_internal_destroy_qp);
ASSIGN_SYM(ibvSymbols, ibv_fork_init, ibv_internal_fork_init);
ASSIGN_SYM(ibvSymbols, ibv_event_type_str, ibv_internal_event_type_str);
ibvSymbols->ibv_internal_reg_mr = &ibv_internal_reg_mr;
ibvSymbols->ibv_internal_query_port = &ibv_internal_query_port;
return scclSuccess;
}
} // namespace device
} // namespace net
} // namespace hardware
} // namespace sccl
#pragma once
#include <infiniband/verbs.h>
#include "base.h"
namespace sccl {
namespace hardware {
namespace net {
namespace device {
/* IB Verbs Function Pointers*/
struct scclIbvSymbols {
    int (*ibv_internal_fork_init)(void); // initialize fork support
    struct ibv_device** (*ibv_internal_get_device_list)(int* num_devices); // get the device list
    void (*ibv_internal_free_device_list)(struct ibv_device** list); // free the device list
    const char* (*ibv_internal_get_device_name)(struct ibv_device* device); // get the device name
    struct ibv_context* (*ibv_internal_open_device)(struct ibv_device* device); // open a device
    int (*ibv_internal_close_device)(struct ibv_context* context); // close a device
    int (*ibv_internal_get_async_event)(struct ibv_context* context, struct ibv_async_event* event); // get an async event
    void (*ibv_internal_ack_async_event)(struct ibv_async_event* event); // acknowledge an async event
    int (*ibv_internal_query_device)(struct ibv_context* context, struct ibv_device_attr* device_attr); // query device attributes
    int (*ibv_internal_query_port)(struct ibv_context* context, uint8_t port_num, struct ibv_port_attr* port_attr); // query port attributes
    int (*ibv_internal_query_gid)(struct ibv_context* context, uint8_t port_num, int index, union ibv_gid* gid); // query a global identifier (GID)
    int (*ibv_internal_query_qp)(struct ibv_qp* qp, struct ibv_qp_attr* attr, int attr_mask, struct ibv_qp_init_attr* init_attr); // query queue pair attributes
    struct ibv_pd* (*ibv_internal_alloc_pd)(struct ibv_context* context); // allocate a protection domain
    int (*ibv_internal_dealloc_pd)(struct ibv_pd* pd); // free a protection domain
    struct ibv_mr* (*ibv_internal_reg_mr)(struct ibv_pd* pd, void* addr, size_t length, int access); // register a memory region
    struct ibv_mr* (*ibv_internal_reg_mr_iova2)(struct ibv_pd* pd, void* addr, size_t length, uint64_t iova, unsigned int access); // register a memory region (IOVA variant)
    /* DMA-BUF support */
    struct ibv_mr* (*ibv_internal_reg_dmabuf_mr)(struct ibv_pd* pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access); // register a DMA-BUF memory region
    int (*ibv_internal_dereg_mr)(struct ibv_mr* mr); // deregister a memory region
    struct ibv_cq* (*ibv_internal_create_cq)(
        struct ibv_context* context, int cqe, void* cq_context, struct ibv_comp_channel* channel, int comp_vector); // create a completion queue
    int (*ibv_internal_destroy_cq)(struct ibv_cq* cq); // destroy a completion queue
    struct ibv_qp* (*ibv_internal_create_qp)(struct ibv_pd* pd, struct ibv_qp_init_attr* qp_init_attr); // create a queue pair
    int (*ibv_internal_modify_qp)(struct ibv_qp* qp, struct ibv_qp_attr* attr, int attr_mask); // modify a queue pair
    int (*ibv_internal_destroy_qp)(struct ibv_qp* qp); // destroy a queue pair
    const char* (*ibv_internal_event_type_str)(enum ibv_event_type event); // get a string describing an event type
};
/* Constructs IB verbs symbols per rdma-core linking or dynamic loading mode */
scclResult_t buildIbvSymbols(struct scclIbvSymbols* ibvSymbols);
} // namespace device
} // namespace net
} // namespace hardware
} // namespace sccl
/*************************************************************************
* Copyright (c) 2015-2022, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include "ibvwrap.h"
#include <sys/types.h>
#include <unistd.h>
#include "ibvsymbols.h"
namespace sccl {
namespace hardware {
namespace net {
namespace device {
static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT;
static scclResult_t initResult;
struct scclIbvSymbols ibvSymbols;
/**
 * Initialize and fetch the IB Verbs symbol table.
 *
 * Uses pthread_once to initialize the IB Verbs symbol table in a thread-safe
 * way, and returns the initialization result.
 *
 * @return scclResult_t scclSuccess on success, an error code otherwise
 */
scclResult_t wrap_ibv_symbols(void) {
pthread_once(&initOnceControl, []() { initResult = buildIbvSymbols(&ibvSymbols); });
return initResult;
}
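// Hypothetical caller sketch: initialize the symbol table once, then use the
// wrappers defined below. The function name and the SCCL_NET log flag are
// assumptions for illustration, not library API.
#if 0
scclResult_t listIbDevices() {
    SCCLCHECK(wrap_ibv_symbols());
    struct ibv_device** devs = NULL;
    int nDevs = 0;
    SCCLCHECK(wrap_ibv_get_device_list(&devs, &nDevs));
    for(int i = 0; i < nDevs; i++)
        INFO(SCCL_NET, "Found IB device %s", wrap_ibv_get_device_name(devs[i]));
    SCCLCHECK(wrap_ibv_free_device_list(devs));
    return scclSuccess;
}
#endif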
/* CHECK_NOT_NULL: helper macro to check for NULL symbol */
#define CHECK_NOT_NULL(container, internal_name) \
if(container.internal_name == NULL) { \
WARN("lib wrapper not initialized."); \
return scclInternalError; \
}
#define IBV_PTR_CHECK_ERRNO(container, internal_name, call, retval, error_retval, name) \
CHECK_NOT_NULL(container, internal_name); \
retval = container.call; \
if(retval == error_retval) { \
WARN("Call to " name " failed with error %s", strerror(errno)); \
return scclSystemError; \
} \
return scclSuccess;
#define IBV_PTR_CHECK(container, internal_name, call, retval, error_retval, name) \
CHECK_NOT_NULL(container, internal_name); \
retval = container.call; \
if(retval == error_retval) { \
WARN("Call to " name " failed"); \
return scclSystemError; \
} \
return scclSuccess;
#define IBV_INT_CHECK_RET_ERRNO(container, internal_name, call, success_retval, name) \
CHECK_NOT_NULL(container, internal_name); \
int ret = container.call; \
if(ret != success_retval) { \
WARN("Call to " name " failed with error %s", strerror(ret)); \
return scclSystemError; \
} \
return scclSuccess;
#define IBV_INT_CHECK(container, internal_name, call, error_retval, name) \
CHECK_NOT_NULL(container, internal_name); \
int ret = container.call; \
if(ret == error_retval) { \
WARN("Call to " name " failed"); \
return scclSystemError; \
} \
return scclSuccess;
#define IBV_PASSTHRU(container, internal_name, call) \
CHECK_NOT_NULL(container, internal_name); \
container.call; \
return scclSuccess;
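// For reference, the wrapper just below,
//   IBV_INT_CHECK(ibvSymbols, ibv_internal_fork_init, ibv_internal_fork_init(), -1, "ibv_fork_init"),
// expands to roughly:
//   if(ibvSymbols.ibv_internal_fork_init == NULL) { WARN("lib wrapper not initialized."); return scclInternalError; }
//   int ret = ibvSymbols.ibv_internal_fork_init();
//   if(ret == -1) { WARN("Call to ibv_fork_init failed"); return scclSystemError; }
//   return scclSuccess;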
/**
 * Initialize RDMA fork support.
 *
 * Wraps ibv_fork_init, which makes it safe to keep using RDMA resources
 * across fork().
 *
 * @return scclSuccess on success, an error code otherwise
 */
scclResult_t wrap_ibv_fork_init() { IBV_INT_CHECK(ibvSymbols, ibv_internal_fork_init, ibv_internal_fork_init(), -1, "ibv_fork_init"); }
scclResult_t wrap_ibv_get_device_list(struct ibv_device*** ret, int* num_devices) {
*ret = ibvSymbols.ibv_internal_get_device_list(num_devices);
if(*ret == NULL)
*num_devices = 0;
return scclSuccess;
}
scclResult_t wrap_ibv_free_device_list(struct ibv_device** list) {
IBV_PASSTHRU(ibvSymbols, ibv_internal_free_device_list, ibv_internal_free_device_list(list));
}
const char* wrap_ibv_get_device_name(struct ibv_device* device) {
if(ibvSymbols.ibv_internal_get_device_name == NULL) {
WARN("lib wrapper not initialized.");
exit(-1);
}
return ibvSymbols.ibv_internal_get_device_name(device);
}
scclResult_t wrap_ibv_open_device(struct ibv_context** ret, struct ibv_device* device) { /*returns 0 on success, -1 on failure*/
IBV_PTR_CHECK(ibvSymbols, ibv_internal_open_device, ibv_internal_open_device(device), *ret, NULL, "ibv_open_device");
}
scclResult_t wrap_ibv_close_device(struct ibv_context* context) { /*returns 0 on success, -1 on failure*/
IBV_INT_CHECK(ibvSymbols, ibv_internal_close_device, ibv_internal_close_device(context), -1, "ibv_close_device");
}
scclResult_t wrap_ibv_get_async_event(struct ibv_context* context, struct ibv_async_event* event) { /*returns 0 on success, and -1 on error*/
IBV_INT_CHECK(ibvSymbols, ibv_internal_get_async_event, ibv_internal_get_async_event(context, event), -1, "ibv_get_async_event");
}
scclResult_t wrap_ibv_ack_async_event(struct ibv_async_event* event) {
IBV_PASSTHRU(ibvSymbols, ibv_internal_ack_async_event, ibv_internal_ack_async_event(event));
}
scclResult_t
wrap_ibv_query_device(struct ibv_context* context,
struct ibv_device_attr* device_attr) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_query_device, ibv_internal_query_device(context, device_attr), 0, "ibv_query_device");
}
scclResult_t
wrap_ibv_query_port(struct ibv_context* context,
uint8_t port_num,
struct ibv_port_attr* port_attr) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_query_port, ibv_internal_query_port(context, port_num, port_attr), 0, "ibv_query_port");
}
scclResult_t wrap_ibv_query_gid(struct ibv_context* context, uint8_t port_num, int index, union ibv_gid* gid) {
IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_query_gid, ibv_internal_query_gid(context, port_num, index, gid), 0, "ibv_query_gid");
}
scclResult_t wrap_ibv_query_qp(struct ibv_qp* qp, struct ibv_qp_attr* attr, int attr_mask, struct ibv_qp_init_attr* init_attr) {
IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_query_qp, ibv_internal_query_qp(qp, attr, attr_mask, init_attr), 0, "ibv_query_qp");
}
scclResult_t wrap_ibv_alloc_pd(struct ibv_pd** ret, struct ibv_context* context) {
IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_alloc_pd, ibv_internal_alloc_pd(context), *ret, NULL, "ibv_alloc_pd");
}
scclResult_t wrap_ibv_dealloc_pd(struct ibv_pd* pd) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_dealloc_pd, ibv_internal_dealloc_pd(pd), 0, "ibv_dealloc_pd");
}
scclResult_t wrap_ibv_reg_mr(struct ibv_mr** ret, struct ibv_pd* pd, void* addr, size_t length, int access) {
IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_reg_mr, ibv_internal_reg_mr(pd, addr, length, access), *ret, NULL, "ibv_reg_mr");
}
struct ibv_mr* wrap_direct_ibv_reg_mr(struct ibv_pd* pd, void* addr, size_t length, int access) {
if(ibvSymbols.ibv_internal_reg_mr == NULL) {
WARN("lib wrapper not initialized.");
return NULL;
}
return ibvSymbols.ibv_internal_reg_mr(pd, addr, length, access);
}
scclResult_t wrap_ibv_reg_mr_iova2(struct ibv_mr** ret, struct ibv_pd* pd, void* addr, size_t length, uint64_t iova, int access) {
if(ibvSymbols.ibv_internal_reg_mr_iova2 == NULL) {
return scclInternalError;
}
if(ret == NULL) {
return scclSuccess;
} // Assume dummy call
IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_reg_mr_iova2, ibv_internal_reg_mr_iova2(pd, addr, length, iova, access), *ret, NULL, "ibv_reg_mr_iova2");
}
/* DMA-BUF support */
scclResult_t wrap_ibv_reg_dmabuf_mr(struct ibv_mr** ret, struct ibv_pd* pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access) {
IBV_PTR_CHECK_ERRNO(
ibvSymbols, ibv_internal_reg_dmabuf_mr, ibv_internal_reg_dmabuf_mr(pd, offset, length, iova, fd, access), *ret, NULL, "ibv_reg_dmabuf_mr");
}
struct ibv_mr* wrap_direct_ibv_reg_dmabuf_mr(struct ibv_pd* pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access) {
if(ibvSymbols.ibv_internal_reg_dmabuf_mr == NULL) {
errno = EOPNOTSUPP; // scclIbDmaBufSupport() requires this errno being set
return NULL;
}
return ibvSymbols.ibv_internal_reg_dmabuf_mr(pd, offset, length, iova, fd, access);
}
scclResult_t wrap_ibv_dereg_mr(struct ibv_mr* mr) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_dereg_mr, ibv_internal_dereg_mr(mr), 0, "ibv_dereg_mr");
}
scclResult_t
wrap_ibv_create_cq(struct ibv_cq** ret, struct ibv_context* context, int cqe, void* cq_context, struct ibv_comp_channel* channel, int comp_vector) {
IBV_PTR_CHECK_ERRNO(
ibvSymbols, ibv_internal_create_cq, ibv_internal_create_cq(context, cqe, cq_context, channel, comp_vector), *ret, NULL, "ibv_create_cq");
}
scclResult_t wrap_ibv_destroy_cq(struct ibv_cq* cq) {
IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_destroy_cq, ibv_internal_destroy_cq(cq), 0, "ibv_destroy_cq");
}
scclResult_t wrap_ibv_destroy_qp(struct ibv_qp* qp) {
IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_destroy_qp, ibv_internal_destroy_qp(qp), 0, "ibv_destroy_qp");
}
scclResult_t wrap_ibv_create_qp(struct ibv_qp** ret, struct ibv_pd* pd, struct ibv_qp_init_attr* qp_init_attr) {
IBV_PTR_CHECK_ERRNO(ibvSymbols, ibv_internal_create_qp, ibv_internal_create_qp(pd, qp_init_attr), *ret, NULL, "ibv_create_qp");
}
scclResult_t wrap_ibv_modify_qp(struct ibv_qp* qp,
struct ibv_qp_attr* attr,
int attr_mask) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
IBV_INT_CHECK_RET_ERRNO(ibvSymbols, ibv_internal_modify_qp, ibv_internal_modify_qp(qp, attr, attr_mask), 0, "ibv_modify_qp");
}
scclResult_t wrap_ibv_event_type_str(char** ret, enum ibv_event_type event) {
*ret = (char*)ibvSymbols.ibv_internal_event_type_str(event);
return scclSuccess;
}
scclResult_t wrap_ibv_poll_cq(struct ibv_cq* cq, int num_entries, struct ibv_wc* wc, int* num_done) {
int done = cq->context->ops.poll_cq(cq, num_entries, wc); /*returns the number of wcs or 0 on success, a negative number otherwise*/
if(done < 0) {
WARN("Call to ibv_poll_cq() returned %d", done);
return scclSystemError;
}
*num_done = done;
return scclSuccess;
}
scclResult_t wrap_ibv_post_send(struct ibv_qp* qp, struct ibv_send_wr* wr, struct ibv_send_wr** bad_wr) {
int ret = qp->context->ops.post_send(qp, wr, bad_wr); /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
if(ret != IBV_SUCCESS) {
WARN("ibv_post_send() failed with error %s, Bad WR %p, First WR %p", strerror(ret), wr, *bad_wr);
return scclSystemError;
}
return scclSuccess;
}
scclResult_t wrap_ibv_post_recv(struct ibv_qp* qp, struct ibv_recv_wr* wr, struct ibv_recv_wr** bad_wr) {
int ret = qp->context->ops.post_recv(qp, wr, bad_wr); /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/
if(ret != IBV_SUCCESS) {
WARN("ibv_post_recv() failed with error %s", strerror(ret));
return scclSystemError;
}
return scclSuccess;
}
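// Hypothetical drain loop built on the wrappers above: poll the CQ in small
// batches until it is empty, checking each completion status. The function
// name and batch size are caller-side assumptions for illustration.
#if 0
static scclResult_t drainCq(struct ibv_cq* cq) {
    struct ibv_wc wcs[4];
    int done = 0;
    do {
        SCCLCHECK(wrap_ibv_poll_cq(cq, 4, wcs, &done));
        for(int i = 0; i < done; i++) {
            if(wcs[i].status != IBV_WC_SUCCESS) {
                WARN("Work completion failed with status %d", wcs[i].status);
                return scclSystemError;
            }
        }
    } while(done > 0);
    return scclSuccess;
}
#endif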
} // namespace device
} // namespace net
} // namespace hardware
} // namespace sccl
#pragma once
#include <infiniband/verbs.h>
#include <sys/types.h>
#include <unistd.h>
#include "base.h"
#include "ibvsymbols.h"
namespace sccl {
namespace hardware {
namespace net {
namespace device {
typedef enum ibv_return_enum : uint8_t {
IBV_SUCCESS = 0, //!< The operation was successful
} ibv_return_t;
// Wrapper: initialize the ibv symbol table
scclResult_t wrap_ibv_symbols(void);
// Wrapper: SCCL initialization of ibv fork support
scclResult_t wrap_ibv_fork_init(void);
// Get the device list
scclResult_t wrap_ibv_get_device_list(struct ibv_device*** ret, int* num_devices);
// Free the device list
scclResult_t wrap_ibv_free_device_list(struct ibv_device** list);
// Get the device name
const char* wrap_ibv_get_device_name(struct ibv_device* device);
// Open a device
scclResult_t wrap_ibv_open_device(struct ibv_context** ret, struct ibv_device* device);
// Close a device
scclResult_t wrap_ibv_close_device(struct ibv_context* context);
// Get an async event
scclResult_t wrap_ibv_get_async_event(struct ibv_context* context, struct ibv_async_event* event);
// Acknowledge an async event
scclResult_t wrap_ibv_ack_async_event(struct ibv_async_event* event);
// Query device attributes
scclResult_t wrap_ibv_query_device(struct ibv_context* context, struct ibv_device_attr* device_attr);
// Query port attributes
scclResult_t wrap_ibv_query_port(struct ibv_context* context, uint8_t port_num, struct ibv_port_attr* port_attr);
// Query a global identifier (GID)
scclResult_t wrap_ibv_query_gid(struct ibv_context* context, uint8_t port_num, int index, union ibv_gid* gid);
// Query queue pair (QP) attributes
scclResult_t wrap_ibv_query_qp(struct ibv_qp* qp, struct ibv_qp_attr* attr, int attr_mask, struct ibv_qp_init_attr* init_attr);
// Allocate a protection domain (PD)
scclResult_t wrap_ibv_alloc_pd(struct ibv_pd** ret, struct ibv_context* context);
// Free a protection domain (PD)
scclResult_t wrap_ibv_dealloc_pd(struct ibv_pd* pd);
// Register a memory region (MR)
scclResult_t wrap_ibv_reg_mr(struct ibv_mr** ret, struct ibv_pd* pd, void* addr, size_t length, int access);
// Register a memory region (MR), returning the raw pointer
struct ibv_mr* wrap_direct_ibv_reg_mr(struct ibv_pd* pd, void* addr, size_t length, int access);
// Register a memory region (MR) with an IOVA
scclResult_t wrap_ibv_reg_mr_iova2(struct ibv_mr** ret, struct ibv_pd* pd, void* addr, size_t length, uint64_t iova, int access);
// Register a DMA-BUF memory region (MR)
scclResult_t wrap_ibv_reg_dmabuf_mr(struct ibv_mr** ret, struct ibv_pd* pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access);
// Register a DMA-BUF memory region (MR), returning the raw pointer
struct ibv_mr* wrap_direct_ibv_reg_dmabuf_mr(struct ibv_pd* pd, uint64_t offset, size_t length, uint64_t iova, int fd, int access);
// Deregister a memory region (MR)
scclResult_t wrap_ibv_dereg_mr(struct ibv_mr* mr);
// Create a completion channel
scclResult_t wrap_ibv_create_comp_channel(struct ibv_comp_channel** ret, struct ibv_context* context);
// Destroy a completion channel
scclResult_t wrap_ibv_destroy_comp_channel(struct ibv_comp_channel* channel);
// Create a completion queue (CQ)
scclResult_t wrap_ibv_create_cq(struct ibv_cq** ret, struct ibv_context* context, int cqe, void* cq_context, struct ibv_comp_channel* channel, int comp_vector);
// Destroy a completion queue (CQ)
scclResult_t wrap_ibv_destroy_cq(struct ibv_cq* cq);
// Poll a completion queue (CQ)
scclResult_t wrap_ibv_poll_cq(struct ibv_cq* cq, int num_entries, struct ibv_wc* wc, int* num_done);
// Create a queue pair (QP)
scclResult_t wrap_ibv_create_qp(struct ibv_qp** ret, struct ibv_pd* pd, struct ibv_qp_init_attr* qp_init_attr);
// Modify queue pair (QP) attributes
scclResult_t wrap_ibv_modify_qp(struct ibv_qp* qp, struct ibv_qp_attr* attr, int attr_mask);
// Destroy a queue pair (QP)
scclResult_t wrap_ibv_destroy_qp(struct ibv_qp* qp);
// Post a send work request
scclResult_t wrap_ibv_post_send(struct ibv_qp* qp, struct ibv_send_wr* wr, struct ibv_send_wr** bad_wr);
// Post a receive work request
scclResult_t wrap_ibv_post_recv(struct ibv_qp* qp, struct ibv_recv_wr* wr, struct ibv_recv_wr** bad_wr);
// Get a string describing an event type
scclResult_t wrap_ibv_event_type_str(char** ret, enum ibv_event_type event);
} // namespace device
} // namespace net
} // namespace hardware
} // namespace sccl
#include <assert.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <poll.h>
#include <sys/types.h>
#include <unistd.h>
#include <netdb.h>
#include "net_ib.h"
#include "socket.h"
#include "rocm_wrap.h"
#include "base.h"
namespace sccl {
namespace hardware {
namespace net {
namespace device {
namespace net_ib {
///////////////////////////////////////// Environment variable parameters /////////////////////////////////////////
// InfiniBand GID index; default 0
SCCL_PARAM(IbGidIndex, "IB_GID_INDEX", 0);
// InfiniBand timeout; default 18
SCCL_PARAM(IbTimeout, "IB_TIMEOUT", 18);
// InfiniBand retry count; default 7
SCCL_PARAM(IbRetryCnt, "IB_RETRY_CNT", 7);
// InfiniBand partition key; default 0
SCCL_PARAM(IbPkey, "IB_PKEY", 0);
// Whether to use InfiniBand inline data; default 0 (disabled)
SCCL_PARAM(IbUseInline, "IB_USE_INLINE", 0);
// InfiniBand service level; default 0
SCCL_PARAM(IbSl, "IB_SL", 0);
// InfiniBand traffic class; default 0
SCCL_PARAM(IbTc, "IB_TC", 0);
// InfiniBand adaptive routing threshold; default 8192
SCCL_PARAM(IbArThreshold, "IB_AR_THRESHOLD", 8192);
// InfiniBand PCI relaxed ordering; default 2
SCCL_PARAM(IbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2);
// Whether to enable InfiniBand adaptive routing; default -2 (likely "use the device default")
SCCL_PARAM(IbAdaptiveRouting, "IB_ADAPTIVE_ROUTING", -2);
// InfiniBand socket client port reuse; default 0 (no reuse)
SCCL_PARAM(IbSockClientPortReuse, "IB_SOCK_CLIENT_PORT_REUSE", 0);
// InfiniBand socket server port reuse; default 0 (no reuse)
SCCL_PARAM(IbSockServerPortReuse, "IB_SOCK_SERVER_PORT_REUSE", 0);
// Whether to disable InfiniBand; default 0 (enabled)
SCCL_PARAM(IbDisable, "IB_DISABLE", 0);
// Whether to merge InfiniBand virtual functions; default 1 (merge)
SCCL_PARAM(IbMergeVfs, "IB_MERGE_VFS", 1);
// Number of InfiniBand queue pairs (QPs) per connection; default 1
SCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1);
// Whether to disable the GDR flush; default 0 (enabled)
SCCL_PARAM(IbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0);
// Whether to split data across queue pairs; default 1 (split)
SCCL_PARAM(IbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 1);
///////////////////////////////////////// Parameters and data structures /////////////////////////////////////////
#define MAXNAMESIZE 64
#define MAX_IF_NAME_SIZE 16
static char scclIbIfName[MAX_IF_NAME_SIZE + 1];
static union host::scclSocketAddress scclIbIfAddr;
// Number of InfiniBand devices detected; -1 until initialized
static int scclNIbDevs = -1;
struct scclIbMr {
    uintptr_t addr; // memory address
    int pages; // number of pages
    int refs; // reference count
    ibv_mr* mr; // InfiniBand memory-registration object
};
// Cache of InfiniBand memory-registration objects
struct scclIbMrCache {
    struct scclIbMr* slots; // cache slots holding the registration objects
    int capacity; // capacity of the cache
    int population; // number of slots currently in use
};
// InfiniBand device descriptor, aligned to a 64-byte boundary
struct alignas(64) scclIbDev {
    pthread_mutex_t lock; // mutex for thread synchronization
    int device; // device number
    uint64_t guid; // globally unique identifier
    uint8_t port; // port number
    uint8_t link; // link layer
    int speed; // link speed
    ibv_context* context; // InfiniBand context
    int pdRefs; // protection domain reference count
    ibv_pd* pd; // protection domain
    char devName[MAXNAMESIZE]; // device name
    char* pciPath; // PCI path
    int realPort; // actual port in use
    int maxQp; // maximum number of queue pairs
    struct scclIbMrCache mrCache; // memory-registration cache
    int ar; // ADAPTIVE_ROUTING flag
};
struct userIbDev {
char devName[MAXNAMESIZE];
uint16_t port_en;
};
// 定义最大InfiniBand设备数量为16
static constexpr int MAX_IB_DEVS = 16;
// 定义一个结构体数组,用于存储InfiniBand设备信息
struct scclIbDev scclIbDevs[MAX_IB_DEVS];
// 定义一个结构体数组,用于存储用户级别的InfiniBand设备信息
struct userIbDev userIbDevs[MAX_IB_DEVS];
// 定义一个互斥锁,用于保护对InfiniBand设备的并发访问
pthread_mutex_t scclIbLock = PTHREAD_MUTEX_INITIALIZER;
// 定义一个静态整数,用于指示是否启用了InfiniBand的Relaxed Ordering模式
static int scclIbRelaxedOrderingEnabled = 0;
// 定义一个线程局部变量,用于存储重用的地址信息
static thread_local union host::scclSocketAddress reusedAddr;
// 定义一个线程局部变量,用于存储重用的套接字文件描述符
static thread_local int reusedSockfd = -1;
// 定义一个线程ID,用于异步线程操作
pthread_t scclIbAsyncThread;
// 定义一个常量,表示InfiniBand网络接口的最大接收数量
static constexpr int SCCL_NET_IB_MAX_RECVS = 8;
// 定义一个常量,表示最大字符串长度
static constexpr int MAX_STR_LEN = 8;
// 为每个并发接收支持SCCL_NET_MAX_REQUESTS
static constexpr int MAX_REQUESTS = (SCCL_NET_MAX_REQUESTS * SCCL_NET_IB_MAX_RECVS);
static_assert(MAX_REQUESTS <= 256, "request id are encoded in wr_id and we need up to 8 requests ids per completion");
// Retain local and remote RoCE addresses for error logging
struct scclIbGidInfo {
    uint8_t link_layer;      // link layer type (IB or Ethernet/RoCE)
    union ibv_gid localGid;  // GID of the local device
    union ibv_gid remoteGid; // GID of the remote device
};
/*
scclIbRequest describes one InfiniBand communication request: the verbs
resources it uses, the request type, and its data buffers. The union holds
the type-specific fields for either a send or a receive.
*/
struct scclIbRequest {
    struct scclIbVerbs* verbs;     // verbs resources (PD, CQ, request pool) used by this request
    int type;                      // request type: send, recv or flush
    int events;                    // number of completions still expected
    struct host::scclSocket* sock; // associated network socket
    struct scclIbGidInfo* gidInfo; // GID info, used for RoCE error logging
    int nreqs;                     // number of grouped requests
    // Type-specific request data
    union {
        // send request
        struct {
            int size;      // bytes to send
            void* data;    // send buffer
            uint32_t lkey; // local memory key
            int offset;    // progress offset into the buffer
        } send;
        // recv request
        struct {
            int sizes[SCCL_NET_IB_MAX_RECVS]; // received sizes, up to SCCL_NET_IB_MAX_RECVS entries
        } recv;
    };
};
/* Bundles the InfiniBand resources of one connection so they can be managed and reused together. */
struct scclIbVerbs {
    int dev;           // index of the InfiniBand device in use
    struct ibv_pd* pd; // protection domain, used for memory registration and QP creation
    struct ibv_cq* cq; // completion queue tracking the status of asynchronous operations
    uint64_t pad[1];   // padding for alignment
    struct scclIbRequest reqs[MAX_REQUESTS]; // request pool (MAX_REQUESTS entries)
};
/* One element of the send FIFO: metadata for a buffer the receiver has posted, written into the sender's memory via RDMA. */
struct alignas(64) scclIbSendFifo {
    uint64_t addr;  // remote destination address
    int size;       // size of the posted receive, in bytes
    uint32_t rkey;  // remote key for RDMA access
    uint32_t nreqs; // number of grouped requests (multi-receive)
    uint32_t tag;   // tag matching a send to its receive
    uint64_t idx;   // monotonically increasing slot index, used to detect valid entries
};
static constexpr int SCCL_IB_MAX_QPS = 128; // maximum number of queue pairs per connection
struct scclIbSendComm {
    struct scclIbVerbs verbs; // RDMA verbs resources
    struct scclIbSendFifo fifo[MAX_REQUESTS][SCCL_NET_IB_MAX_RECVS]; // send FIFO, written remotely by the receiver
    uint64_t fifoHead;        // FIFO head index
    struct scclIbRequest* fifoReqs[MAX_REQUESTS][SCCL_NET_IB_MAX_RECVS]; // requests matched to each FIFO slot
    struct ibv_send_wr wrs[SCCL_NET_IB_MAX_RECVS + 1]; // send work requests
    struct ibv_sge sges[SCCL_NET_IB_MAX_RECVS];        // scatter-gather elements
    struct host::scclSocket sock; // out-of-band socket
    int ready;                    // handshake completed
    struct ibv_qp* qps[SCCL_IB_MAX_QPS]; // queue pairs
    int nqps;                     // number of queue pairs
    int qpIndex;                  // next queue pair to use (round robin)
    struct ibv_mr* fifoMr;        // memory registration covering the FIFO
    int ar;                       // adaptive routing flag
    struct scclIbGidInfo gidInfo; // GID info
};
/* IB connection-setup state */
enum scclIbCommState : uint8_t {
    scclIbCommStateStart = 0,        // initial state
    scclIbCommStateConnect = 1,      // connecting
    scclIbCommStateAccept = 3,       // accepting a connection
    scclIbCommStateSend = 4,         // sending handshake data
    scclIbCommStateRecv = 5,         // receiving handshake data
    scclIbCommStateConnecting = 6,   // waiting for the remote QP info
    scclIbCommStateConnected = 7,    // connected
    scclIbCommStatePendingReady = 8, // waiting for the ready flag
};
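// State-machine flow, as driven by the functions further down: scclIbConnect()
// walks Start -> Connect -> Send -> Connecting -> Connected, while
// scclIbAccept() walks Start -> Accept -> Recv -> Send -> PendingReady. Both
// functions return scclSuccess with a NULL comm while a step is still in
// flight, so callers are expected to retry until the output comm is non-NULL.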
/* Progress of a connection setup */
struct scclIbCommStage {
    enum scclIbCommState state; // current state
    int offset;                 // progress offset into buffer
    void* buffer;               // staging buffer for handshake data
    void* comm;                 // communicator being set up
};
/* Listening-side context */
struct scclIbListenComm {
    int dev;                      // device index
    struct host::scclSocket sock; // listening socket
    struct scclIbCommStage stage; // connection-setup progress
};
struct scclIbQpInfo {
uint32_t lid;
uint8_t ib_port;
uint8_t link_layer;
uint32_t qpn[SCCL_IB_MAX_QPS];
// For RoCE
uint64_t spn;
uint64_t iid;
enum ibv_mtu mtu;
// FIFO RDMA info
uint32_t fifoRkey;
uint64_t fifoAddr;
};
struct scclIbGpuFlush {
int enabled;
int hostMem;
struct ibv_mr* hostMr;
struct ibv_sge sge;
struct ibv_qp* qp;
};
struct scclIbRemFifo {
struct scclIbSendFifo elems[MAX_REQUESTS][SCCL_NET_IB_MAX_RECVS];
uint64_t fifoTail;
uint64_t addr;
uint32_t rkey;
uint32_t flags;
struct ibv_mr* mr;
struct ibv_sge sge;
};
struct scclIbRecvComm {
struct scclIbVerbs verbs;
struct scclIbRemFifo remFifo;
struct host::scclSocket sock;
int ready;
struct ibv_qp* qps[SCCL_IB_MAX_QPS];
int nqps;
int qpIndex;
struct scclIbGpuFlush gpuFlush;
struct scclIbGidInfo gidInfo;
};
static_assert((offsetof(struct scclIbRecvComm, remFifo) % 32) == 0, "scclIbRecvComm remFifo must be 32-byte aligned");
///////////////////////////////////////// net_ib helper functions /////////////////////////////////////////
/**
 * @brief Main loop of the IB asynchronous event thread
 *
 * Runs as a standalone thread and keeps polling the IB device for
 * asynchronous events. Every event except IBV_EVENT_COMM_EST is logged as a
 * warning. Each event must be acknowledged with wrap_ibv_ack_async_event
 * after it has been handled.
 *
 * @param args thread argument, an ibv_context pointer
 * @return void* always NULL
 */
static void* scclIbAsyncThreadMain(void* args) {
    struct ibv_context* context = (struct ibv_context*)args;
    // Poll for asynchronous events until fetching one fails
    while(1) {
        struct ibv_async_event event;
        if(scclSuccess != wrap_ibv_get_async_event(context, &event)) {
            break;
        }
        // Translate the event type into a human-readable string
        char* str;
        if(scclSuccess != wrap_ibv_event_type_str(&str, event.event_type)) {
            break;
        }
        // Log everything except communication-established events
        if(event.event_type != IBV_EVENT_COMM_EST)
            WARN("NET/IB : Got async event : %s", str);
        // Acknowledge the event; stop if that fails
        if(scclSuccess != wrap_ibv_ack_async_event(&event)) {
            break;
        }
    }
    return NULL;
}
/**
 * @brief Get the PCI path of an IB device, merging multi-port NICs and VFs
 *
 * Resolves the device's real PCI path from sysfs. Multi-port NICs and,
 * optionally, virtual functions (VFs) are normalized so that they map to the
 * same PCI device. Also counts how many known devices share this path.
 *
 * @param devName [in] IB device name
 * @param path [out] resolved PCI path
 * @param realPort [out] port index among devices sharing this PCI path
 * @return scclResult_t scclSuccess on success
 */
static scclResult_t scclIbGetPciPath(char* devName, char** path, int* realPort) {
    char devicePath[PATH_MAX];
    // Build the sysfs path "/sys/class/infiniband/<devName>/device"
    snprintf(devicePath, PATH_MAX, "/sys/class/infiniband/%s/device", devName);
    char* p = realpath(devicePath, NULL);
    if(p == NULL) {
        WARN("Could not find real path of %s (%s)", devName, devicePath);
    } else {
        // Merge multi-port NICs onto the same PCI device by zeroing the trailing port digit
        p[strlen(p) - 1] = '0';
        // Also merge virtual functions (VFs) onto the same device if requested
        if(scclParamIbMergeVfs())
            p[strlen(p) - 3] = p[strlen(p) - 4] = '0';
        // Count how many already-known devices share this PCI path
        *realPort = 0;
        for(int d = 0; d < scclNIbDevs; d++) {
            if(strcmp(p, scclIbDevs[d].pciPath) == 0)
                (*realPort)++;
        }
    }
    *path = p;
    return scclSuccess;
}
static int ibvWidths[] = {1, 4, 8, 12, 2};
static int ibvSpeeds[] = {2500, /* SDR */
5000, /* DDR */
10000, /* QDR */
10000, /* QDR */
14000, /* FDR */
25000, /* EDR */
50000, /* HDR */
100000 /* NDR */};
/**
 * Find the index of the first set bit
 * @param val value to scan
 * @param max maximum number of bits to check
 * @return index of the first set bit, or max if none is set
 */
static int firstBitSet(int val, int max) {
    int i = 0;
    while(i < max && ((val & (1 << i)) == 0))
        i++;
    return i;
}
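// Example: firstBitSet(0b0100, 8) == 2, and firstBitSet(0, 8) == 8 (the max),
// which is why the lookup helpers below pass array length minus one as max:
// an unknown bitmask falls back to the last table entry.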
/**
 * Map an ibv port width bitmask to the link width in lanes
 * @param width active_width bitmask from ibv_port_attr
 * @return the corresponding entry of ibvWidths
 */
static int scclIbWidth(int width) { return ibvWidths[firstBitSet(width, sizeof(ibvWidths) / sizeof(int) - 1)]; }
/**
 * Map an ibv port speed bitmask to the per-lane speed in Mbps
 * @param speed active_speed bitmask from ibv_port_attr
 * @return the corresponding entry of ibvSpeeds
 */
static int scclIbSpeed(int speed) { return ibvSpeeds[firstBitSet(speed, sizeof(ibvSpeeds) / sizeof(int) - 1)]; }
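// Worked example (bit positions follow the usual verbs encoding): an EDR 4X
// port reports active_speed = 32 (bit 5 -> 25000 Mbps per lane) and
// active_width = 2 (bit 1 -> 4 lanes), so scclIbSpeed() * scclIbWidth()
// = 100000 Mbps, which is exactly how scclIbInit() below computes and logs
// the link speed.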
/**
 * Check whether the IB stack supports PCI relaxed ordering (RO)
 *
 * @return 1 if supported, 0 otherwise
 * @note Probes the IBVERBS_1.8 ibv_reg_mr_iova2 symbol, which is required
 *       for IBV_ACCESS_RELAXED_ORDERING support
 * @see scclParamIbPciRelaxedOrdering() for the configured RO mode
 */
static int scclIbRelaxedOrderingCapable(void) {
    int roMode = scclParamIbPciRelaxedOrdering();
    scclResult_t r = scclInternalError;
    if(roMode == 1 || roMode == 2) {
        // Query IBVERBS_1.8 API - needed for IBV_ACCESS_RELAXED_ORDERING support
        r = wrap_ibv_reg_mr_iova2(NULL, NULL, NULL, 0, 0, 0);
    }
    return r == scclInternalError ? 0 : 1;
}
/**
 * @brief Read the user-specified IB device filter from the environment
 *
 * Reads SCCL_IB_HCA, which selects the IB devices/ports to use. Two prefixes
 * are recognized:
 * - '^' invert the match (exclude the listed devices)
 * - '=' require an exact match
 *
 * @param shownIbHcaEnv counter limiting how often the value is logged
 * @param searchNot [out] set when the '^' prefix is present
 * @param searchExact [out] set when the '=' prefix is present
 * @return char* the device list with any prefixes stripped, or NULL if unset
 */
static char* scclIbGetIbHca(int& shownIbHcaEnv, bool* searchNot, bool* searchExact) {
    // Check whether the user specified which IB device:port to use
    char* userIbEnv = getenv("SCCL_IB_HCA");
    if(userIbEnv != NULL && shownIbHcaEnv++ == 0)
        INFO(SCCL_LOG_NET, "SCCL_IB_HCA set to %s", userIbEnv);
    *searchNot = userIbEnv && userIbEnv[0] == '^';
    if(*searchNot)
        userIbEnv++;
    *searchExact = userIbEnv && userIbEnv[0] == '=';
    if(*searchExact)
        userIbEnv++;
    return userIbEnv;
}
/**
 * @brief Read a string from a sysfs/procfs file
 *
 * Joins path and fileName, opens the file, and reads its contents into
 * strValue. If the file cannot be read or is empty, strValue is set to the
 * empty string and an informational message is logged.
 *
 * @param path directory containing the file
 * @param fileName file to read
 * @param strValue output buffer for the contents
 * @return scclResult_t always scclSuccess
 *
 * @note At most MAX_STR_LEN bytes are read; longer contents are truncated.
 *       The last byte read (usually a trailing newline) is replaced by '\0'.
 */
scclResult_t scclGetStrFromSys(const char* path, const char* fileName, char* strValue) {
    char filePath[PATH_MAX];
    sprintf(filePath, "%s/%s", path, fileName);
    int offset = 0;
    FILE* file;
    if((file = fopen(filePath, "r")) != NULL) {
        while(feof(file) == 0 && ferror(file) == 0 && offset < MAX_STR_LEN) {
            int len = fread(strValue + offset, 1, MAX_STR_LEN - offset, file);
            offset += len;
        }
        fclose(file);
    }
    if(offset == 0) {
        strValue[0] = '\0';
        INFO(SCCL_LOG_NET, "System detection : could not read %s, ignoring", filePath);
    } else {
        strValue[offset - 1] = '\0';
    }
    return scclSuccess;
}
/**
 * @brief Check whether the IB device supports GPU Direct RDMA (GDR)
 *
 * Detects whether the current system supports GPU Direct RDMA. On the HIP
 * platform this checks the amdkfd peer-memory module, the BIOS version (to
 * detect Hyper-V), and the NUMA-balancing setting; on other platforms GDR is
 * reported as unsupported.
 *
 * @param ibDev IB device index
 * @return scclResult_t scclSuccess if supported, scclSystemError otherwise
 */
scclResult_t scclIbGdrSupport(int ibDev) {
static int moduleLoaded = -1;
if(moduleLoaded == -1) {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
moduleLoaded = (access("/sys/kernel/mm/memory_peers/amdkfd/version", F_OK) == -1) ? 0 : 1;
char strValue[MAX_STR_LEN];
SCCLCHECK(scclGetStrFromSys("/sys/devices/virtual/dmi/id", "bios_version", strValue));
if(strncmp("Hyper-V UEFI Release", strValue, 20) == 0) {
int roMode = scclParamIbPciRelaxedOrdering();
SCCLCHECK(scclGetStrFromSys("/proc/sys/kernel", "numa_balancing", strValue));
if(strcmp(strValue, "1") == 0 && roMode == 0)
moduleLoaded = 0;
}
#else
moduleLoaded = 0;
#endif
}
if(moduleLoaded == 0)
return scclSystemError;
return scclSuccess;
}
/**
 * @brief Check whether the device supports DMA-BUF registration
 *
 * Probes DMA-BUF support by attempting to register an invalid DMA-BUF file
 * descriptor; the resulting errno distinguishes "unsupported" from "bad fd".
 * The result is cached so the probe only runs once.
 *
 * @param dev device index
 * @return scclResult_t scclSuccess if supported, scclSystemError otherwise
 */
scclResult_t scclIbDmaBufSupport(int dev) {
static int dmaBufSupported = -1;
if(dmaBufSupported == -1) {
scclResult_t res;
SCCLCHECKGOTO(rocmLibraryInit(), res, failure);
struct ibv_pd* pd;
struct ibv_context* ctx;
ctx = scclIbDevs[dev].context;
SCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure);
// Test kernel DMA-BUF support with a dummy call (fd=-1)
(void)wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL /*offset*/, 0ULL /*len*/, 0ULL /*iova*/, -1 /*fd*/, 0 /*flags*/);
// ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise)
dmaBufSupported = (errno != EOPNOTSUPP && errno != EPROTONOSUPPORT) ? 1 : 0;
SCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure);
}
if(dmaBufSupported == 0)
return scclSystemError;
return scclSuccess;
failure:
dmaBufSupported = 0;
return scclSystemError;
}
struct scclIbHandle {
    union host::scclSocketAddress connectAddr; // Filled by the target
    uint64_t magic;                            // random number to help debugging
    struct scclIbCommStage stage;              // Used by the other side when connecting
};
/**
 * @brief Initialize the InfiniBand verbs resources for a device
 *
 * Sets up the verbs state for the given device:
 * - allocates the protection domain (PD) on first use
 * - creates the completion queue (CQ)
 *
 * @param dev device index
 * @param ctx IB device context
 * @param verbs verbs structure to initialize
 * @return scclResult_t scclSuccess on success
 *
 * @note Increments the device's PD reference count; the PD itself is
 *       allocated only on the first call for a device
 * @note The CQ is sized 2 * MAX_REQUESTS * IB_QPS_PER_CONNECTION, since each
 *       receive can generate two completions
 */
scclResult_t scclIbInitVerbs(int dev, struct ibv_context* ctx, struct scclIbVerbs* verbs) {
verbs->dev = dev;
pthread_mutex_lock(&scclIbDevs[dev].lock);
if(0 == scclIbDevs[dev].pdRefs++) {
scclResult_t res;
SCCLCHECKGOTO(wrap_ibv_alloc_pd(&scclIbDevs[dev].pd, ctx), res, failure);
if(0) {
failure:
pthread_mutex_unlock(&scclIbDevs[dev].lock);
return res;
}
}
verbs->pd = scclIbDevs[dev].pd;
pthread_mutex_unlock(&scclIbDevs[dev].lock);
// Recv requests can generate 2 completions (one for the post FIFO, one for the Recv).
SCCLCHECK(wrap_ibv_create_cq(&verbs->cq, ctx, 2 * MAX_REQUESTS * scclParamIbQpsPerConn(), NULL, NULL, 0));
return scclSuccess;
}
scclResult_t scclIbCreateQp(uint8_t ib_port, struct scclIbVerbs* verbs, int access_flags, struct ibv_qp** qp) {
struct ibv_qp_init_attr qpInitAttr;
memset(&qpInitAttr, 0, sizeof(struct ibv_qp_init_attr));
qpInitAttr.send_cq = verbs->cq;
qpInitAttr.recv_cq = verbs->cq;
qpInitAttr.qp_type = IBV_QPT_RC;
// We might send 2 messages per send (RDMA and RDMA_WITH_IMM)
qpInitAttr.cap.max_send_wr = 2 * MAX_REQUESTS;
qpInitAttr.cap.max_recv_wr = MAX_REQUESTS;
qpInitAttr.cap.max_send_sge = 1;
qpInitAttr.cap.max_recv_sge = 1;
qpInitAttr.cap.max_inline_data = scclParamIbUseInline() ? sizeof(struct scclIbSendFifo) : 0;
SCCLCHECK(wrap_ibv_create_qp(qp, verbs->pd, &qpInitAttr));
struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_INIT;
qpAttr.pkey_index = scclParamIbPkey();
qpAttr.port_num = ib_port;
qpAttr.qp_access_flags = access_flags;
SCCLCHECK(wrap_ibv_modify_qp(*qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS));
return scclSuccess;
}
scclResult_t scclIbRtrQp(struct ibv_qp* qp, uint32_t qpn, struct scclIbQpInfo* info) {
struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_RTR;
qpAttr.path_mtu = info->mtu;
qpAttr.dest_qp_num = qpn;
qpAttr.rq_psn = 0;
qpAttr.max_dest_rd_atomic = 1;
qpAttr.min_rnr_timer = 12;
if(info->link_layer == IBV_LINK_LAYER_ETHERNET) {
qpAttr.ah_attr.is_global = 1;
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->spn;
qpAttr.ah_attr.grh.dgid.global.interface_id = info->iid;
qpAttr.ah_attr.grh.flow_label = 0;
qpAttr.ah_attr.grh.sgid_index = scclParamIbGidIndex();
qpAttr.ah_attr.grh.hop_limit = 255;
qpAttr.ah_attr.grh.traffic_class = scclParamIbTc();
} else {
qpAttr.ah_attr.is_global = 0;
qpAttr.ah_attr.dlid = info->lid;
}
qpAttr.ah_attr.sl = scclParamIbSl();
qpAttr.ah_attr.src_path_bits = 0;
qpAttr.ah_attr.port_num = info->ib_port;
SCCLCHECK(wrap_ibv_modify_qp(
qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER));
return scclSuccess;
}
scclResult_t scclIbRtsQp(struct ibv_qp* qp) {
struct ibv_qp_attr qpAttr;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_RTS;
qpAttr.timeout = scclParamIbTimeout();
qpAttr.retry_cnt = scclParamIbRetryCnt();
qpAttr.rnr_retry = 7;
qpAttr.sq_psn = 0;
qpAttr.max_rd_atomic = 1;
SCCLCHECK(wrap_ibv_modify_qp(qp, &qpAttr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC));
return scclSuccess;
}
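// Taken together, the three helpers above walk a reliable-connected QP through
// the standard verbs state machine: scclIbCreateQp() creates the QP and moves
// it RESET -> INIT, scclIbRtrQp() moves it INIT -> RTR once the peer's QP
// number and path information are known, and scclIbRtsQp() moves it RTR -> RTS.
// A minimal sketch of the expected call order (error handling elided; this
// mirrors what scclIbConnect()/scclIbAccept() below actually do):
//
//   struct ibv_qp* qp;
//   scclIbCreateQp(ib_port, &verbs, IBV_ACCESS_REMOTE_WRITE, &qp); // -> INIT
//   /* ...exchange scclIbQpInfo with the peer over the OOB socket... */
//   scclIbRtrQp(qp, remoteQpInfo.qpn[0], &remoteQpInfo);           // -> RTR
//   scclIbRtsQp(qp);                                               // -> RTS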
#define SCCL_NET_IB_REQ_UNUSED 0
#define SCCL_NET_IB_REQ_SEND 1
#define SCCL_NET_IB_REQ_RECV 2
#define SCCL_NET_IB_REQ_FLUSH 3
const char* reqTypeStr[] = {"Unused", "Send", "Recv", "Flush"};
// The SendFifo needs to be 32-byte aligned and each element needs
// to be a 32-byte multiple, so that an entry does not get split and
// written out of order when IB Relaxed Ordering is enabled
static_assert((offsetof(struct scclIbSendComm, fifo) % 32) == 0, "scclIbSendComm fifo must be 32-byte aligned");
static_assert((sizeof(struct scclIbSendFifo) % 32) == 0, "scclIbSendFifo element size must be 32-byte multiples");
scclResult_t scclIbDestroyVerbs(struct scclIbVerbs* verbs) {
scclResult_t res;
SCCLCHECK(wrap_ibv_destroy_cq(verbs->cq));
pthread_mutex_lock(&scclIbDevs[verbs->dev].lock);
if(0 == --scclIbDevs[verbs->dev].pdRefs) {
SCCLCHECKGOTO(wrap_ibv_dealloc_pd(scclIbDevs[verbs->dev].pd), res, returning);
}
res = scclSuccess;
returning:
pthread_mutex_unlock(&scclIbDevs[verbs->dev].lock);
return res;
}
scclResult_t scclIbGetRequest(struct scclIbVerbs* verbs, struct scclIbRequest** req) {
for(int i = 0; i < MAX_REQUESTS; i++) {
struct scclIbRequest* r = verbs->reqs + i;
if(r->type == SCCL_NET_IB_REQ_UNUSED) {
r->verbs = verbs;
r->events = 1;
r->sock = NULL;
r->gidInfo = NULL;
*req = r;
return scclSuccess;
}
}
WARN("NET/IB : unable to allocate requests");
*req = NULL;
return scclInternalError;
}
scclResult_t scclIbFreeRequest(struct scclIbRequest* r) {
r->type = SCCL_NET_IB_REQ_UNUSED;
return scclSuccess;
}
scclResult_t scclIbTest(void* request, int* done, int* size);
scclResult_t scclIbMultiSend(struct scclIbSendComm* comm, int slot) {
struct scclIbRequest** reqs = comm->fifoReqs[slot];
volatile struct scclIbSendFifo* slots = comm->fifo[slot];
int nreqs = slots[0].nreqs;
if(nreqs > SCCL_NET_IB_MAX_RECVS)
return scclInternalError;
uint64_t wr_id = 0ULL;
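// Pack the pool index of each grouped request (reqs[r] - verbs.reqs, always
// < MAX_REQUESTS <= 256) into bits [8r, 8r+8) of wr_id; scclIbTest() unpacks
// them from the completion to credit each request.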
for(int r = 0; r < nreqs; r++) {
struct ibv_send_wr* wr = comm->wrs + r;
memset(wr, 0, sizeof(struct ibv_send_wr));
struct ibv_sge* sge = comm->sges + r;
sge->addr = (uintptr_t)reqs[r]->send.data;
sge->lkey = reqs[r]->send.lkey;
wr->opcode = IBV_WR_RDMA_WRITE;
wr->send_flags = 0;
wr->wr.rdma.remote_addr = slots[r].addr;
wr->wr.rdma.rkey = slots[r].rkey;
wr->next = wr + 1;
wr_id += (reqs[r] - comm->verbs.reqs) << (r * 8);
}
// Write size as immediate data. In the case of multi-send, only write
// 0 or 1 as size to indicate whether there was data sent or received.
uint32_t immData = 0;
if(nreqs == 1) {
immData = reqs[0]->send.size;
} else {
if(nreqs > 32) {
WARN("Cannot store sizes of %d requests in a 32-bits field", nreqs);
return scclInternalError;
}
for(int r = 0; r < nreqs; r++) {
immData |= (reqs[r]->send.size ? 1 : 0) << r;
}
}
struct ibv_send_wr* lastWr = comm->wrs + nreqs - 1;
if(nreqs > 1 || (comm->ar && reqs[0]->send.size > scclParamIbArThreshold())) {
// When using ADAPTIVE_ROUTING, send the bulk of the data first as an
// RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote
// completion.
lastWr++;
memset(lastWr, 0, sizeof(struct ibv_send_wr));
}
lastWr->wr_id = wr_id;
lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
lastWr->imm_data = immData;
lastWr->next = NULL;
lastWr->send_flags = IBV_SEND_SIGNALED;
// Multi-QP: make sure IB writes are multiples of 128B so that LL and LL128 protocols still work
const int align = 128;
const int nqps = scclParamIbSplitDataOnQps() ? comm->nqps : 1;
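// Worked example of the split below: size = 1000 bytes over nqps = 2 gives
// chunkSize = DIVUP(DIVUP(1000, 2), 128) * 128 = 512, so the first QP carries
// 512 bytes and the second the remaining 488 (lengths are clamped by
// size - offset, and fully-consumed requests post zero SGEs).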
for(int q = 0; q < nqps; q++) {
for(int r = 0; r < nreqs; r++) {
int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align;
int length = std::min(reqs[r]->send.size - reqs[r]->send.offset, chunkSize);
if(length <= 0) {
comm->wrs[r].sg_list = NULL;
comm->wrs[r].num_sge = 0;
} else {
comm->sges[r].length = length;
comm->wrs[r].sg_list = comm->sges + r;
comm->wrs[r].num_sge = 1;
}
}
struct ibv_send_wr* bad_wr;
SCCLCHECK(wrap_ibv_post_send(comm->qps[comm->qpIndex], comm->wrs, &bad_wr));
comm->qpIndex = (comm->qpIndex + 1) % comm->nqps;
for(int r = 0; r < nreqs; r++) {
int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align;
reqs[r]->send.offset += chunkSize;
comm->sges[r].addr += chunkSize;
comm->wrs[r].wr.rdma.remote_addr += chunkSize;
}
}
return scclSuccess;
}
scclResult_t scclIbPostFifo(struct scclIbRecvComm* comm, int n, void** data, int* sizes, int* tags, void** mhandles, struct scclIbRequest* req) {
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
int slot = comm->remFifo.fifoTail % MAX_REQUESTS;
struct scclIbSendFifo* localElem = comm->remFifo.elems[slot];
for(int i = 0; i < n; i++) {
localElem[i].addr = (uint64_t)data[i];
struct ibv_mr* mr = (struct ibv_mr*)mhandles[i];
localElem[i].rkey = mr->rkey;
localElem[i].nreqs = n;
localElem[i].size = sizes[i]; // Sanity/Debugging
localElem[i].tag = tags[i];
localElem[i].idx = comm->remFifo.fifoTail + 1;
}
wr.wr.rdma.remote_addr = comm->remFifo.addr + slot * SCCL_NET_IB_MAX_RECVS * sizeof(struct scclIbSendFifo);
wr.wr.rdma.rkey = comm->remFifo.rkey;
comm->remFifo.sge.addr = (uint64_t)localElem;
comm->remFifo.sge.length = n * sizeof(struct scclIbSendFifo);
wr.sg_list = &comm->remFifo.sge;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_WRITE;
wr.send_flags = comm->remFifo.flags; // IBV_SEND_INLINE
// We need to occasionally post a request with the IBV_SEND_SIGNALED flag, otherwise
// the send queue will never empty.
//
// From https://www.rdmamojo.com/2014/06/30/working-unsignaled-completions/
// "How to use Unsignaled Completion?" / "Gotchas and Pitfalls"
// All posted Send Requested, Signaled and Unsignaled, are considered outstanding until
// a Work Completion that they, or Send Requests that were posted after them, was polled
// from the Completion Queue associated with the Send Queue. This means if one works with
// a Queue Pair that was configured to work with Unsignaled Completions, he must make
// sure that occasionally (before the Send Queue is full with outstanding Send Requests)
// a Send Request that generate Work Completion will be posted.
//
// Not following this rule may lead to a case that the Send Queue is full with Send
// Requests that won't generate Work Completion:
//
// - The Send Queue is full, so no new Send Requests can be posted to it
// - The Send Queue can't be emptied, since no Work Completion can be generated anymore
// (the reason is that no Work Completion, that can generate Work Completion that
// polling it will empty the Send Queue, can be posted)
// - The status of all posted Send Request is considered unknown
//
if(slot == 0) {
wr.send_flags |= IBV_SEND_SIGNALED;
wr.wr_id = req - comm->verbs.reqs;
req->events++;
}
struct ibv_send_wr* bad_wr;
SCCLCHECK(wrap_ibv_post_send(comm->qps[0], &wr, &bad_wr));
comm->remFifo.fifoTail++;
return scclSuccess;
}
} // namespace net_ib
//////////////////////////////////////// Functions referenced by scclNetIb ////////////////////////////////////////
namespace net_ib {
/**
 * @brief Initialize the InfiniBand devices
 *
 * Detects and initializes the available InfiniBand devices:
 * - loads the IB verbs symbols
 * - finds an IP interface for out-of-band bootstrap
 * - queries device and port attributes
 * - applies the user-specified HCA filter
 * - spawns an asynchronous event thread per device
 *
 * @return scclResult_t scclSuccess on success, scclInternalError on failure
 *
 * @note Honors the SCCL_IB_HCA environment variable to filter devices
 * @note Uses scclIbLock for thread safety
 */
scclResult_t scclIbInit(void) {
    // Respect the IB-disable switch
    if(scclParamIbDisable())
        return scclInternalError;
    // Load the IB verbs symbols; fail if they are unavailable
    if(wrap_ibv_symbols() != scclSuccess) {
        return scclInternalError;
    } else {
        INFO(SCCL_LOG_NET, "SCCL IB init done");
    }
    static int shownIbHcaEnv = 0;
    // Initialize the device list once
    if(scclNIbDevs == -1) {
        pthread_mutex_lock(&scclIbLock);
        wrap_ibv_fork_init();
        if(scclNIbDevs == -1) {
            scclNIbDevs = 0;
            // Find an IP interface for the out-of-band bootstrap socket
            if(host::scclFindSocketInterfaces(scclIbIfName, &scclIbIfAddr, MAX_IF_NAME_SIZE, 1) != 1) {
                WARN("NET/IB : No IP interface found.");
                return scclInternalError;
            }
            // Detect the IB cards
            int nIbDevs;
            struct ibv_device** devices;
            struct netIf userIfs[MAX_IB_DEVS];
            bool searchNot, searchExact;
            // Read the user-specified HCA filter (SCCL_IB_HCA)
            char* userIbEnv = scclIbGetIbHca(shownIbHcaEnv, &searchNot, &searchExact);
            // Parse the user-specified interface list, up to MAX_IB_DEVS entries
            int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS);
            // Get the device list
            if(scclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs))
                return scclInternalError;
            // Iterate over all devices
            for(int d = 0; d < nIbDevs && scclNIbDevs < MAX_IB_DEVS; d++) {
                struct ibv_context* context;
                // Try to open the device
                if(scclSuccess != wrap_ibv_open_device(&context, devices[d]) || context == NULL) {
                    WARN("NET/IB : Unable to open device %s", devices[d]->name);
                    continue;
                }
                int nPorts = 0;
                struct ibv_device_attr devAttr;
                memset(&devAttr, 0, sizeof(devAttr));
                // Query the device attributes
                if(scclSuccess != wrap_ibv_query_device(context, &devAttr)) {
                    WARN("NET/IB : Unable to query device %s", devices[d]->name);
                    if(scclSuccess != wrap_ibv_close_device(context)) {
                        return scclInternalError;
                    }
                    continue;
                }
                // Iterate over the device's ports
                for(int port = 1; port <= devAttr.phys_port_cnt; port++) {
                    struct ibv_port_attr portAttr;
                    // Query the port attributes
                    if(scclSuccess != wrap_ibv_query_port(context, port, &portAttr)) {
                        WARN("NET/IB : Unable to query port %d", port);
                        continue;
                    }
                    // Only keep active IB/RoCE ports
                    if(portAttr.state != IBV_PORT_ACTIVE)
                        continue;
                    if(portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET)
                        continue;
                    // Apply the user-specified HCA/port filter
                    if(!(matchIfList(devices[d]->name, port, userIfs, nUserIfs, searchExact) ^ searchNot)) {
                        continue;
                    }
                    INFO(SCCL_LOG_NET,
                         "NET/IB: [%d] %s: port=%d/IB=%s, speed:%d/%d",
                         d,
                         devices[d]->name,
                         port,
                         portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE",
                         scclIbSpeed(portAttr.active_speed),
                         scclIbSpeed(portAttr.active_speed) * scclIbWidth(portAttr.active_width));
                    pthread_mutex_init(&scclIbDevs[scclNIbDevs].lock, NULL);
                    INFO(SCCL_LOG_NET, "d=%d, node_guid=%llu, sys_image_guid=%llu\n", d, devAttr.node_guid, devAttr.sys_image_guid);
                    // Record the device's properties
                    {
                        scclIbDevs[scclNIbDevs].device = d;                    // device index
                        scclIbDevs[scclNIbDevs].guid = devAttr.sys_image_guid; // system image GUID
                        scclIbDevs[scclNIbDevs].port = port;                   // port number
                        scclIbDevs[scclNIbDevs].link = portAttr.link_layer;    // link layer type
                        scclIbDevs[scclNIbDevs].speed = scclIbSpeed(portAttr.active_speed) * scclIbWidth(portAttr.active_width); // link speed
                        scclIbDevs[scclNIbDevs].context = context;             // device context
                        scclIbDevs[scclNIbDevs].pdRefs = 0;                    // protection domain refcount
                        scclIbDevs[scclNIbDevs].pd = NULL;                     // protection domain
                        strncpy(scclIbDevs[scclNIbDevs].devName, devices[d]->name, MAXNAMESIZE); // device name
                        SCCLCHECK(scclIbGetPciPath(
                            scclIbDevs[scclNIbDevs].devName, &scclIbDevs[scclNIbDevs].pciPath, &scclIbDevs[scclNIbDevs].realPort)); // PCI path and real port
                        scclIbDevs[scclNIbDevs].maxQp = devAttr.max_qp;        // maximum number of QPs
                        scclIbDevs[scclNIbDevs].mrCache.capacity = 0;          // MR cache capacity
                        scclIbDevs[scclNIbDevs].mrCache.population = 0;        // MR cache population
                        scclIbDevs[scclNIbDevs].mrCache.slots = NULL;          // MR cache slots
                        // Enable ADAPTIVE_ROUTING by default on IB networks, but allow the environment parameter to override it
                        scclIbDevs[scclNIbDevs].ar = (portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND) ? 1 : 0;
                        if(scclParamIbAdaptiveRouting() != -2)
                            scclIbDevs[scclNIbDevs].ar = scclParamIbAdaptiveRouting();
                    }
                    // Spawn the asynchronous event thread for this device
                    pthread_create(&scclIbAsyncThread, NULL, scclIbAsyncThreadMain, context);
                    // Name the thread to ease debugging and identification
                    scclSetThreadName(scclIbAsyncThread, "SCCL IbAsync %2d", scclNIbDevs);
                    // Detach the thread (exactly once) so its resources are reclaimed
                    // automatically, with no pthread_join() needed
                    pthread_detach(scclIbAsyncThread);
                    scclNIbDevs++;
                    nPorts++;
                }
                // Close the device if it exposed no usable port
                if(nPorts == 0 && scclSuccess != wrap_ibv_close_device(context)) {
                    return scclInternalError;
                }
            }
            // Free the device list
            if(nIbDevs && (scclSuccess != wrap_ibv_free_device_list(devices))) {
                return scclInternalError;
            };
        }
        // Report the outcome of device detection
        if(scclNIbDevs == 0) {
            WARN("NET/IB : No device found.");
        } else {
            char line[1024];
            line[0] = '\0';
            // Determine whether RELAXED_ORDERING is available
            scclIbRelaxedOrderingEnabled = scclIbRelaxedOrderingCapable();
            for(int d = 0; d < scclNIbDevs; d++) {
                snprintf(line + strlen(line),
                         1023 - strlen(line),
                         " -- [%d]%s:%d/%s; ",
                         d,
                         scclIbDevs[d].devName,
                         scclIbDevs[d].port,
                         scclIbDevs[d].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE");
            }
            // Make sure the summary line is NUL-terminated
            line[1023] = '\0';
            // Printable form of the OOB socket address
            char addrline[SOCKET_NAME_MAXLEN + 1];
            // Log the device list, the relaxed-ordering state, and the OOB interface
INFO(SCCL_LOG_NET,
"NET/IB : Using%s %s; OOB %s:%s",
line,
scclIbRelaxedOrderingEnabled ? "[RO]" : "",
scclIbIfName,
host::scclSocketToString(&scclIbIfAddr, addrline));
}
pthread_mutex_unlock(&scclIbLock);
}
return scclSuccess;
}
/**
 * Get the number of available InfiniBand devices
 *
 * @param ndev [out] receives the device count
 * @return scclResult_t scclSuccess on success
 */
scclResult_t scclIbGetDevicesNum(int* ndev) {
*ndev = scclNIbDevs;
return scclSuccess;
}
/**
 * @brief Get the network properties of an IB device
 *
 * Reports the device's name, PCI path, GUID, supported pointer types, speed,
 * latency, port number, maximum communicator count, and maximum
 * grouped-receive count.
 *
 * @param dev device index
 * @param props output properties structure
 * @return scclResult_t scclSuccess on success
 */
scclResult_t scclIbGetProperties(int dev, scclNetProperties_t* props) {
props->name = scclIbDevs[dev].devName;
props->pciPath = scclIbDevs[dev].pciPath;
props->guid = scclIbDevs[dev].guid;
props->ptrSupport = SCCL_PTR_HOST;
if(scclIbGdrSupport(dev) == scclSuccess) {
props->ptrSupport |= SCCL_PTR_CUDA; // GDR support via the peer-memory kernel module
}
if(scclIbDmaBufSupport(dev) == scclSuccess) {
props->ptrSupport |= SCCL_PTR_DMABUF; // GDR support via DMA-BUF
}
props->speed = scclIbDevs[dev].speed;
props->latency = 0; // Not set
props->port = scclIbDevs[dev].port + scclIbDevs[dev].realPort;
props->maxComms = scclIbDevs[dev].maxQp;
props->maxRecvs = SCCL_NET_IB_MAX_RECVS;
return scclSuccess;
}
/**
 * @brief Create and initialize an IB listening communicator on a device
 *
 * @param dev device index
 * @param opaqueHandle opaque handle that will carry the connection info
 * @param listenComm [out] the created listening communicator
 * @return scclResult_t operation status
 *
 * The function:
 * 1. allocates and initializes the listening communicator
 * 2. records the device index and the handle magic
 * 3. optionally reuses the listening socket (server-side port reuse)
 * 4. starts listening and publishes the connect address through the handle
 */
scclResult_t scclIbListen(int dev, void* opaqueHandle, void** listenComm) {
    // Allocate and initialize the listening communicator
    struct scclIbListenComm* comm;
    SCCLCHECK(scclCalloc(&comm, 1));
    struct scclIbHandle* handle = (struct scclIbHandle*)opaqueHandle;
    static_assert(sizeof(struct scclIbHandle) < SCCL_NET_HANDLE_MAXSIZE, "scclIbHandle size too large");
    memset(handle, 0, sizeof(struct scclIbHandle));
    // Record the device and fill the handle
    comm->dev = dev;
    handle->magic = SCCL_SOCKET_MAGIC;
    SCCLCHECK(host::scclSocketInit(&comm->sock, &scclIbIfAddr, handle->magic, host::scclSocketTypeNetIb, NULL, 1));
    // With server-side port reuse enabled, share one listening socket across calls
    if(scclParamIbSockServerPortReuse()) {
        if(reusedSockfd == -1) {
            SCCLCHECK(host::scclSocketListen(&comm->sock));
            memcpy(&reusedAddr, &comm->sock.addr, sizeof(union host::scclSocketAddress));
            reusedSockfd = comm->sock.fd;
        } else {
            memcpy(&comm->sock.addr, &reusedAddr, sizeof(union host::scclSocketAddress));
            comm->sock.fd = reusedSockfd;
        }
    } else {
        SCCLCHECK(host::scclSocketListen(&comm->sock));
    }
    // Publish the socket address through the handle
SCCLCHECK(host::scclSocketGetAddr(&comm->sock, &handle->connectAddr));
*listenComm = comm;
return scclSuccess;
}
scclResult_t scclIbConnect(int dev, void* opaqueHandle, void** sendComm) {
struct scclIbHandle* handle = (struct scclIbHandle*)opaqueHandle;
struct scclIbCommStage* stage = &handle->stage;
struct scclIbSendComm* comm = (struct scclIbSendComm*)stage->comm;
int ready;
*sendComm = NULL;
if(stage->state == scclIbCommStateConnect)
goto ib_connect_check;
if(stage->state == scclIbCommStateSend)
goto ib_send;
if(stage->state == scclIbCommStateConnecting)
goto ib_connect;
if(stage->state == scclIbCommStateConnected)
goto ib_send_ready;
if(stage->state != scclIbCommStateStart) {
WARN("Error: trying to connect already connected sendComm");
return scclInternalError;
}
SCCLCHECK(scclIbMalloc((void**)&comm, sizeof(struct scclIbSendComm)));
SCCLCHECK(host::scclSocketInit(&comm->sock, &handle->connectAddr, handle->magic, host::scclSocketTypeNetIb, NULL, 1));
stage->comm = comm;
stage->state = scclIbCommStateConnect;
SCCLCHECK(host::scclSocketConnect(&comm->sock, scclParamIbSockClientPortReuse()));
ib_connect_check:
/* since scclSocketConnect is async, we must check if connection is complete */
SCCLCHECK(host::scclSocketReady(&comm->sock, &ready));
if(!ready)
return scclSuccess;
// IB Setup
struct ibv_context* ctx;
ctx = scclIbDevs[dev].context;
SCCLCHECK(scclIbInitVerbs(dev, ctx, &comm->verbs));
uint8_t ib_port;
ib_port = scclIbDevs[dev].port;
comm->nqps = scclParamIbQpsPerConn();
for(int q = 0; q < comm->nqps; q++) {
SCCLCHECK(scclIbCreateQp(ib_port, &comm->verbs, IBV_ACCESS_REMOTE_WRITE, comm->qps + q));
}
comm->ar = scclIbDevs[dev].ar; // ADAPTIVE_ROUTING
// Send my QP Info to receiver through the socket. Hope this won't block.
struct ibv_port_attr portAttr;
SCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr));
struct scclIbQpInfo qpInfo;
qpInfo.ib_port = ib_port;
for(int q = 0; q < comm->nqps; q++)
qpInfo.qpn[q] = comm->qps[q]->qp_num;
qpInfo.mtu = portAttr.active_mtu;
// Prepare my fifo
SCCLCHECK(wrap_ibv_reg_mr(&comm->fifoMr,
comm->verbs.pd,
comm->fifo,
sizeof(struct scclIbSendFifo) * MAX_REQUESTS * SCCL_NET_IB_MAX_RECVS,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ));
qpInfo.fifoRkey = comm->fifoMr->rkey;
qpInfo.fifoAddr = (uint64_t)comm->fifo;
// RoCE support
qpInfo.lid = portAttr.lid;
qpInfo.link_layer = comm->gidInfo.link_layer = portAttr.link_layer;
if(qpInfo.link_layer == IBV_LINK_LAYER_INFINIBAND) { // IB
for(int q = 0; q < comm->nqps; q++)
INFO(SCCL_LOG_NET, "NET/IB: Dev %d Port %d qpn %d mtu %d LID %d", dev, ib_port, qpInfo.qpn[q], qpInfo.mtu, qpInfo.lid);
} else { // RoCE
SCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, scclParamIbGidIndex(), &comm->gidInfo.localGid));
qpInfo.spn = comm->gidInfo.localGid.global.subnet_prefix;
qpInfo.iid = comm->gidInfo.localGid.global.interface_id;
for(int q = 0; q < comm->nqps; q++)
INFO(SCCL_LOG_NET,
"NET/IB: Dev %d Port %d qpn %d mtu %d GID %ld (%lX/%lX)",
dev,
ib_port,
qpInfo.qpn[q],
qpInfo.mtu,
scclParamIbGidIndex(),
qpInfo.spn,
qpInfo.iid);
}
stage->state = scclIbCommStateSend;
stage->offset = 0;
SCCLCHECK(scclIbMalloc((void**)&stage->buffer, sizeof(qpInfo)));
memcpy(stage->buffer, &qpInfo, sizeof(qpInfo));
ib_send:
SCCLCHECK(host::scclSocketProgress(SCCL_SOCKET_SEND, &comm->sock, stage->buffer, sizeof(qpInfo), &stage->offset));
if(stage->offset != sizeof(qpInfo))
return scclSuccess;
stage->state = scclIbCommStateConnecting;
stage->offset = 0;
// Clear the staging buffer for re-use
memset(stage->buffer, 0, sizeof(qpInfo));
ib_connect:
struct scclIbQpInfo remQpInfo;
SCCLCHECK(host::scclSocketProgress(SCCL_SOCKET_RECV, &comm->sock, stage->buffer, sizeof(scclIbQpInfo), &stage->offset));
if(stage->offset != sizeof(remQpInfo))
return scclSuccess;
memcpy(&remQpInfo, stage->buffer, sizeof(scclIbQpInfo));
comm->gidInfo.remoteGid.global.subnet_prefix = remQpInfo.spn;
comm->gidInfo.remoteGid.global.interface_id = remQpInfo.iid;
for(int q = 0; q < comm->nqps; q++) {
struct ibv_qp* qp = comm->qps[q];
SCCLCHECK(scclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo));
SCCLCHECK(scclIbRtsQp(qp));
}
comm->ready = 1;
stage->state = scclIbCommStateConnected;
stage->offset = 0;
ib_send_ready:
SCCLCHECK(host::scclSocketProgress(SCCL_SOCKET_SEND, &comm->sock, &comm->ready, sizeof(int), &stage->offset));
if(stage->offset != sizeof(int))
return scclSuccess;
free(stage->buffer);
stage->state = scclIbCommStateStart;
*sendComm = comm;
return scclSuccess;
}
/**
 * @brief Accept an IB connection request and establish the channel
 *
 * Handles the accept side of connection setup:
 * 1. initializes the receive communicator
 * 2. accepts the socket connection
 * 3. exchanges QP information
 * 4. creates and configures the QPs
 * 5. records the remote FIFO information
 * 6. sets up the GPU Direct RDMA flush resources
 * 7. completes the handshake
 *
 * @param listenComm listening communicator handle
 * @param recvComm [out] receive communicator handle
 * @return scclResult_t scclSuccess on success
 */
scclResult_t scclIbAccept(void* listenComm, void** recvComm) {
struct scclIbListenComm* lComm = (struct scclIbListenComm*)listenComm;
struct scclIbCommStage* stage = &lComm->stage;
struct scclIbRecvComm* rComm = (struct scclIbRecvComm*)stage->comm;
int ready;
*recvComm = NULL;
if(stage->state == scclIbCommStateAccept)
goto ib_accept_check;
if(stage->state == scclIbCommStateRecv)
goto ib_recv;
if(stage->state == scclIbCommStateSend)
goto ib_send;
if(stage->state == scclIbCommStatePendingReady)
goto ib_recv_ready;
if(stage->state != scclIbCommStateStart) {
WARN("Listencomm in unknown state %d", stage->state);
return scclInternalError;
}
SCCLCHECK(scclIbMalloc((void**)&rComm, sizeof(struct scclIbRecvComm)));
stage->comm = rComm;
stage->state = scclIbCommStateAccept;
SCCLCHECK(host::scclSocketInit(&rComm->sock));
SCCLCHECK(host::scclSocketAccept(&rComm->sock, &lComm->sock));
ib_accept_check:
SCCLCHECK(host::scclSocketReady(&rComm->sock, &ready));
if(!ready)
return scclSuccess;
struct scclIbQpInfo remQpInfo;
stage->state = scclIbCommStateRecv;
stage->offset = 0;
SCCLCHECK(scclIbMalloc((void**)&stage->buffer, sizeof(remQpInfo)));
ib_recv:
SCCLCHECK(host::scclSocketProgress(SCCL_SOCKET_RECV, &rComm->sock, stage->buffer, sizeof(remQpInfo), &stage->offset));
if(stage->offset != sizeof(remQpInfo))
return scclSuccess;
/* copy back the received info */
memcpy(&remQpInfo, stage->buffer, sizeof(struct scclIbQpInfo));
rComm->gidInfo.remoteGid.global.subnet_prefix = remQpInfo.spn;
rComm->gidInfo.remoteGid.global.interface_id = remQpInfo.iid;
// IB setup
struct ibv_context* ctx;
uint8_t ib_port;
ctx = scclIbDevs[lComm->dev].context;
ib_port = scclIbDevs[lComm->dev].port;
struct ibv_port_attr portAttr;
SCCLCHECK(wrap_ibv_query_port(ctx, ib_port, &portAttr));
SCCLCHECK(wrap_ibv_query_gid(ctx, ib_port, scclParamIbGidIndex(), &rComm->gidInfo.localGid));
// QP Creation
SCCLCHECK(scclIbInitVerbs(lComm->dev, ctx, &rComm->verbs));
rComm->nqps = scclParamIbQpsPerConn();
for(int q = 0; q < rComm->nqps; q++) {
SCCLCHECK(scclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_REMOTE_WRITE, rComm->qps + q));
}
// Adjust the MTU
remQpInfo.mtu = (enum ibv_mtu)std::min(remQpInfo.mtu, portAttr.active_mtu);
// Setup QP
for(int q = 0; q < rComm->nqps; q++) {
struct ibv_qp* qp = rComm->qps[q];
SCCLCHECK(scclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo));
SCCLCHECK(scclIbRtsQp(qp));
}
// Retain remote fifo info and prepare my RDMA ops
rComm->remFifo.rkey = remQpInfo.fifoRkey;
rComm->remFifo.addr = remQpInfo.fifoAddr;
SCCLCHECK(wrap_ibv_reg_mr(&rComm->remFifo.mr,
rComm->verbs.pd,
&rComm->remFifo.elems,
sizeof(struct scclIbSendFifo) * MAX_REQUESTS * SCCL_NET_IB_MAX_RECVS,
IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ));
rComm->remFifo.sge.lkey = rComm->remFifo.mr->lkey;
if(scclParamIbUseInline())
rComm->remFifo.flags = IBV_SEND_INLINE;
// Allocate Flush dummy buffer for GPU Direct RDMA
rComm->gpuFlush.enabled =
((scclIbGdrSupport(lComm->dev) == scclSuccess || scclIbDmaBufSupport(lComm->dev) == scclSuccess) && (scclParamIbGdrFlushDisable() == 0)) ? 1 : 0;
if(rComm->gpuFlush.enabled) {
SCCLCHECK(wrap_ibv_reg_mr(&rComm->gpuFlush.hostMr, rComm->verbs.pd, &rComm->gpuFlush.hostMem, sizeof(int), IBV_ACCESS_LOCAL_WRITE));
rComm->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlush.hostMem;
rComm->gpuFlush.sge.length = 1;
rComm->gpuFlush.sge.lkey = rComm->gpuFlush.hostMr->lkey;
SCCLCHECK(scclIbCreateQp(ib_port, &rComm->verbs, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ, &rComm->gpuFlush.qp));
struct scclIbQpInfo localQpInfo;
localQpInfo.lid = portAttr.lid;
localQpInfo.link_layer = portAttr.link_layer;
localQpInfo.ib_port = ib_port;
localQpInfo.spn = rComm->gidInfo.localGid.global.subnet_prefix;
localQpInfo.iid = rComm->gidInfo.localGid.global.interface_id;
localQpInfo.mtu = portAttr.active_mtu;
SCCLCHECK(scclIbRtrQp(rComm->gpuFlush.qp, rComm->gpuFlush.qp->qp_num, &localQpInfo));
SCCLCHECK(scclIbRtsQp(rComm->gpuFlush.qp));
}
// Fill Handle
struct scclIbQpInfo qpInfo;
qpInfo.lid = portAttr.lid;
qpInfo.link_layer = rComm->gidInfo.link_layer = portAttr.link_layer;
qpInfo.ib_port = ib_port;
for(int q = 0; q < rComm->nqps; q++)
qpInfo.qpn[q] = rComm->qps[q]->qp_num;
qpInfo.spn = rComm->gidInfo.localGid.global.subnet_prefix;
qpInfo.iid = rComm->gidInfo.localGid.global.interface_id;
qpInfo.mtu = remQpInfo.mtu;
stage->state = scclIbCommStateSend;
stage->offset = 0;
if(stage->buffer)
free(stage->buffer);
SCCLCHECK(scclIbMalloc((void**)&stage->buffer, sizeof(struct scclIbQpInfo)));
memcpy(stage->buffer, &qpInfo, sizeof(struct scclIbQpInfo));
ib_send:
SCCLCHECK(host::scclSocketProgress(SCCL_SOCKET_SEND, &rComm->sock, stage->buffer, sizeof(struct scclIbQpInfo), &stage->offset));
if(stage->offset < sizeof(struct scclIbQpInfo))
return scclSuccess;
stage->offset = 0;
stage->state = scclIbCommStatePendingReady;
ib_recv_ready:
SCCLCHECK(host::scclSocketProgress(SCCL_SOCKET_RECV, &rComm->sock, &rComm->ready, sizeof(int), &stage->offset));
if(stage->offset != sizeof(int))
return scclSuccess;
free(stage->buffer);
*recvComm = rComm;
/* reset lComm stage */
stage->state = scclIbCommStateStart;
stage->offset = 0;
stage->comm = NULL;
stage->buffer = NULL;
return scclSuccess;
}
/* DMA-BUF support */
scclResult_t scclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) {
static_assert(offsetof(struct scclIbSendComm, verbs) == offsetof(struct scclIbRecvComm, verbs), "Send and recv comms must have verbs at the same offset");
assert(size > 0);
static __thread uintptr_t pageSize = 0;
if(pageSize == 0)
pageSize = sysconf(_SC_PAGESIZE);
struct scclIbVerbs* verbs = (struct scclIbVerbs*)comm;
struct scclIbMrCache* cache = &scclIbDevs[verbs->dev].mrCache;
uintptr_t addr = (uintptr_t)data & -pageSize;
size_t pages = ((uintptr_t)data + size - addr + pageSize - 1) / pageSize;
scclResult_t res;
pthread_mutex_lock(&scclIbDevs[verbs->dev].lock);
for(int slot = 0; /*true*/; slot++) {
if(slot == cache->population) { // didn't find in cache
if(cache->population == cache->capacity) { // must grow cache
cache->capacity = cache->capacity < 32 ? 32 : 2 * cache->capacity;
SCCLCHECKGOTO(scclRealloc(&cache->slots, cache->population, cache->capacity), res, returning);
}
// Deregister / register
struct ibv_mr* mr;
unsigned int flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
if(scclIbRelaxedOrderingEnabled)
flags |= IBV_ACCESS_RELAXED_ORDERING;
if(fd != -1) {
/* DMA-BUF support */
SCCLCHECKGOTO(wrap_ibv_reg_dmabuf_mr(&mr, verbs->pd, offset, pages * pageSize, addr, fd, flags), res, returning);
} else {
if(scclIbRelaxedOrderingEnabled) {
// Use IBVERBS_1.8 API - needed for IBV_ACCESS_RELAXED_ORDERING support
SCCLCHECKGOTO(wrap_ibv_reg_mr_iova2(&mr, verbs->pd, (void*)addr, pages * pageSize, addr, flags), res, returning);
} else {
SCCLCHECKGOTO(wrap_ibv_reg_mr(&mr, verbs->pd, (void*)addr, pages * pageSize, flags), res, returning);
}
}
INFO(SCCL_LOG_NET, "regAddr %llx size %lld rkey %x fd %d", (unsigned long long)addr, (long long)pages * pageSize, mr->rkey, fd);
cache->population += 1;
cache->slots[slot].addr = addr;
cache->slots[slot].pages = pages;
cache->slots[slot].refs = 1;
cache->slots[slot].mr = mr;
*mhandle = (void*)mr;
res = scclSuccess;
goto returning;
} else if(cache->slots[slot].addr == addr && cache->slots[slot].pages == pages) {
cache->slots[slot].refs += 1;
*mhandle = (void*)cache->slots[slot].mr;
res = scclSuccess;
goto returning;
}
}
returning:
pthread_mutex_unlock(&scclIbDevs[verbs->dev].lock);
return res;
}
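// Plain (non-DMA-BUF) registration below reuses the DMA-BUF path with fd == -1,
// which falls through to ibv_reg_mr / ibv_reg_mr_iova2 in the cache-miss
// branch above.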
scclResult_t scclIbRegMr(void* comm, void* data, int size, int type, void** mhandle) {
return scclIbRegMrDmaBuf(comm, data, (size_t)size, type, 0ULL, -1, mhandle);
}
scclResult_t scclIbDeregMr(void* comm, void* mhandle) {
struct scclIbVerbs* verbs = (struct scclIbVerbs*)comm;
struct scclIbMrCache* cache = &scclIbDevs[verbs->dev].mrCache;
scclResult_t res;
pthread_mutex_lock(&scclIbDevs[verbs->dev].lock);
for(int i = 0; i < cache->population; i++) {
if(mhandle == cache->slots[i].mr) {
if(0 == --cache->slots[i].refs) {
memmove(&cache->slots[i], &cache->slots[--cache->population], sizeof(struct scclIbMr));
if(cache->population == 0) {
free(cache->slots);
cache->slots = NULL;
cache->capacity = 0;
}
SCCLCHECKGOTO(wrap_ibv_dereg_mr((struct ibv_mr*)mhandle), res, returning);
}
res = scclSuccess;
goto returning;
}
}
WARN("NET/IB: could not find mr %p inside cache of %d entries", mhandle, cache->population);
res = scclInternalError;
returning:
pthread_mutex_unlock(&scclIbDevs[verbs->dev].lock);
return res;
}
scclResult_t scclIbIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) {
struct scclIbSendComm* comm = (struct scclIbSendComm*)sendComm;
if(comm->ready == 0) {
WARN("NET/IB: scclIbIsend() called when comm->ready == 0");
return scclInternalError;
}
if(comm->ready == 0) {
*request = NULL;
return scclSuccess;
}
struct ibv_mr* mr = (struct ibv_mr*)mhandle;
// Wait for the receiver to have posted the corresponding receive
int nreqs = 0;
volatile struct scclIbSendFifo* slots;
int slot = (comm->fifoHead) % MAX_REQUESTS;
struct scclIbRequest** reqs = comm->fifoReqs[slot];
slots = comm->fifo[slot];
uint64_t idx = comm->fifoHead + 1;
if(slots[0].idx != idx) {
*request = NULL;
return scclSuccess;
}
nreqs = slots[0].nreqs;
// Wait until all data has arrived
for(int r = 1; r < nreqs; r++)
while(slots[r].idx != idx)
;
__sync_synchronize(); // order the nreqsPtr load against tag/rkey/addr loads below
for(int r = 0; r < nreqs; r++) {
if(reqs[r] != NULL || slots[r].tag != tag)
continue;
// Sanity checks to catch user collective call count/size mismatches
if(size > slots[r].size) {
char line[SOCKET_NAME_MAXLEN + 1];
union host::scclSocketAddress addr;
host::scclSocketGetAddr(&comm->sock, &addr);
WARN("NET/IB : req %d/%d tag %x peer %s collective mismatch error, local size %d remote size %d",
r,
nreqs,
tag,
host::scclSocketToString(&addr, line),
size,
slots[r].size);
return scclInvalidUsage;
} // plus any potential programming errors
else if(slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkey == 0) {
char line[SOCKET_NAME_MAXLEN + 1];
union host::scclSocketAddress addr;
host::scclSocketGetAddr(&comm->sock, &addr);
WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkey %x",
r,
nreqs,
tag,
host::scclSocketToString(&addr, line),
slots[r].size,
slots[r].addr,
slots[r].rkey);
return scclInternalError;
}
struct scclIbRequest* req;
SCCLCHECK(scclIbGetRequest(&comm->verbs, &req));
req->type = SCCL_NET_IB_REQ_SEND;
req->sock = &comm->sock;
req->verbs = &comm->verbs;
req->nreqs = nreqs;
req->send.size = size;
req->send.data = data;
req->send.lkey = mr->lkey;
req->send.offset = 0;
req->events = scclParamIbSplitDataOnQps() ? comm->nqps : 1;
if(comm->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET)
req->gidInfo = &comm->gidInfo;
*request = reqs[r] = req;
// If this is a multi-recv, send only when all requests have matched.
for(int r = 0; r < nreqs; r++) {
if(reqs[r] == NULL)
return scclSuccess;
}
SCCLCHECK(scclIbMultiSend(comm, slot));
// Clear slots[0]->nreqs, as well as other fields to help debugging and sanity checks
memset((void*)slots, 0, sizeof(struct scclIbSendFifo));
memset(reqs, 0, SCCL_NET_IB_MAX_RECVS * sizeof(struct scclIbRequest*));
comm->fifoHead++;
return scclSuccess;
}
*request = NULL;
return scclSuccess;
}
scclResult_t scclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) {
struct scclIbRecvComm* comm = (struct scclIbRecvComm*)recvComm;
if(comm->ready == 0) {
WARN("NET/IB: scclIbIrecv() called when comm->ready == 0");
return scclInternalError;
}
if(comm->ready == 0) {
*request = NULL;
return scclSuccess;
}
if(n > SCCL_NET_IB_MAX_RECVS)
return scclInternalError;
struct scclIbRequest* req;
SCCLCHECK(scclIbGetRequest(&comm->verbs, &req));
req->type = SCCL_NET_IB_REQ_RECV;
req->sock = &comm->sock;
req->nreqs = n;
if(comm->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET)
req->gidInfo = &comm->gidInfo;
for(int i = 0; i < n; i++)
req->recv.sizes[i] = 0;
struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
wr.wr_id = req - comm->verbs.reqs;
wr.sg_list = NULL;
wr.num_sge = 0;
const int nqps = scclParamIbSplitDataOnQps() ? comm->nqps : 1;
for(int q = 0; q < nqps; q++) {
struct ibv_qp* qp = comm->qps[comm->qpIndex];
struct ibv_recv_wr* bad_wr;
SCCLCHECK(wrap_ibv_post_recv(qp, &wr, &bad_wr));
comm->qpIndex = (comm->qpIndex + 1) % comm->nqps;
}
req->events = nqps;
*request = req;
// Post to FIFO to notify sender
SCCLCHECK(scclIbPostFifo(comm, n, data, sizes, tags, mhandles, req));
return scclSuccess;
}
scclResult_t scclIbIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) {
struct scclIbRecvComm* comm = (struct scclIbRecvComm*)recvComm;
int last = -1;
for(int i = 0; i < n; i++)
if(sizes[i])
last = i;
if(comm->gpuFlush.enabled == 0 || last == -1)
return scclSuccess;
// Only flush once using the last non-zero receive
struct scclIbRequest* req;
SCCLCHECK(scclIbGetRequest(&comm->verbs, &req));
req->type = SCCL_NET_IB_REQ_FLUSH;
req->sock = &comm->sock;
struct ibv_mr* mr = (struct ibv_mr*)mhandles[last];
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
wr.wr_id = req - comm->verbs.reqs;
wr.wr.rdma.remote_addr = (uint64_t)data[last];
wr.wr.rdma.rkey = mr->rkey;
wr.sg_list = &comm->gpuFlush.sge;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_READ;
wr.send_flags = IBV_SEND_SIGNALED;
struct ibv_send_wr* bad_wr;
SCCLCHECK(wrap_ibv_post_send(comm->gpuFlush.qp, &wr, &bad_wr));
*request = req;
return scclSuccess;
}
scclResult_t scclIbTest(void* request, int* done, int* sizes) {
struct scclIbRequest* r = (struct scclIbRequest*)request;
*done = 0;
while(1) {
if(r->events == 0) {
*done = 1;
if(sizes && r->type == SCCL_NET_IB_REQ_RECV) {
for(int i = 0; i < r->nreqs; i++)
sizes[i] = r->recv.sizes[i];
}
SCCLCHECK(scclIbFreeRequest(r));
return scclSuccess;
}
int wrDone = 0;
struct ibv_wc wcs[4];
SCCLCHECK(wrap_ibv_poll_cq(r->verbs->cq, 4, wcs, &wrDone));
if(wrDone == 0)
return scclSuccess;
for(int w = 0; w < wrDone; w++) {
struct ibv_wc* wc = wcs + w;
if(wc->status != IBV_WC_SUCCESS) {
char line[SOCKET_NAME_MAXLEN + 1];
union host::scclSocketAddress addr;
host::scclSocketGetAddr(r->sock, &addr);
char localGidString[INET6_ADDRSTRLEN] = "";
char remoteGidString[INET6_ADDRSTRLEN] = "";
const char *localGidStr = NULL, *remoteGidStr = NULL;
if(r->gidInfo) {
localGidStr = inet_ntop(AF_INET6, &r->gidInfo->localGid, localGidString, sizeof(localGidString));
remoteGidStr = inet_ntop(AF_INET6, &r->gidInfo->remoteGid, remoteGidString, sizeof(remoteGidString));
}
WARN("NET/IB : Got completion from peer %s with error %d, opcode %d, len %d, vendor err %d (%s)%s%s%s%s",
host::scclSocketToString(&addr, line),
wc->status,
wc->opcode,
wc->byte_len,
wc->vendor_err,
reqTypeStr[r->type],
localGidStr ? " localGid " : "",
localGidString,
remoteGidStr ? " remoteGid " : "",
remoteGidString);
return scclRemoteError;
}
struct scclIbRequest* req = r->verbs->reqs + (wc->wr_id & 0xff);
if(req->type == SCCL_NET_IB_REQ_SEND) {
for(int i = 0; i < req->nreqs; i++) {
struct scclIbRequest* sendReq = r->verbs->reqs + ((wc->wr_id >> (i * 8)) & 0xff);
if((sendReq->events <= 0))
return scclInternalError;
sendReq->events--;
}
} else {
if(req && wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
if(req->type != SCCL_NET_IB_REQ_RECV)
return scclInternalError;
if(req->nreqs > 1) {
// In the case of a multi recv, we only set sizes to 0 or 1.
for(int i = 0; i < req->nreqs; i++) {
req->recv.sizes[i] = (wc->imm_data >> i) & 0x1;
}
} else {
req->recv.sizes[0] += wc->imm_data;
}
}
req->events--;
}
}
}
}
scclResult_t scclIbCloseSend(void* sendComm) {
struct scclIbSendComm* comm = (struct scclIbSendComm*)sendComm;
if(comm) {
SCCLCHECK(host::scclSocketClose(&comm->sock));
for(int q = 0; q < comm->nqps; q++)
if(comm->qps[q] != NULL)
SCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q]));
if(comm->fifoMr != NULL)
SCCLCHECK(wrap_ibv_dereg_mr(comm->fifoMr));
SCCLCHECK(scclIbDestroyVerbs(&comm->verbs));
free(comm);
}
return scclSuccess;
}
scclResult_t scclIbCloseRecv(void* recvComm) {
struct scclIbRecvComm* comm = (struct scclIbRecvComm*)recvComm;
if(comm) {
if(!scclParamIbSockServerPortReuse() || reusedSockfd != comm->sock.fd)
SCCLCHECK(host::scclSocketClose(&comm->sock));
for(int q = 0; q < comm->nqps; q++)
if(comm->qps[q] != NULL)
SCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q]));
if(comm->gpuFlush.enabled) {
if(comm->gpuFlush.qp != NULL)
SCCLCHECK(wrap_ibv_destroy_qp(comm->gpuFlush.qp));
if(comm->gpuFlush.hostMr != NULL)
SCCLCHECK(wrap_ibv_dereg_mr(comm->gpuFlush.hostMr));
}
if(comm->remFifo.mr != NULL)
SCCLCHECK(wrap_ibv_dereg_mr(comm->remFifo.mr));
SCCLCHECK(scclIbDestroyVerbs(&comm->verbs));
free(comm);
}
return scclSuccess;
}
scclResult_t scclIbCloseListen(void* listenComm) {
struct scclIbListenComm* comm = (struct scclIbListenComm*)listenComm;
if(comm) {
SCCLCHECK(host::scclSocketClose(&comm->sock));
free(comm);
}
return scclSuccess;
}
} // namespace net_ib
scclNet_t scclNetIb = {"IB",
net_ib::scclIbInit,
net_ib::scclIbGetDevicesNum,
net_ib::scclIbGetProperties,
net_ib::scclIbListen,
net_ib::scclIbConnect,
net_ib::scclIbAccept,
net_ib::scclIbRegMr,
net_ib::scclIbRegMrDmaBuf,
net_ib::scclIbDeregMr,
net_ib::scclIbIsend,
net_ib::scclIbIrecv,
net_ib::scclIbIflush,
net_ib::scclIbTest,
net_ib::scclIbCloseSend,
net_ib::scclIbCloseRecv,
net_ib::scclIbCloseListen};
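// A minimal caller-side sketch of how this vtable is expected to be driven
// (hypothetical member names following the initializer order above -- init,
// devices, getProperties, listen, connect, accept, ... -- since the scclNet_t
// definition is not shown in this file):
//
//   SCCLCHECK(scclNetIb.init());
//   int ndev;
//   SCCLCHECK(scclNetIb.devices(&ndev));
//   char handle[SCCL_NET_HANDLE_MAXSIZE];
//   void* lComm; void* sComm = NULL; void* rComm = NULL;
//   SCCLCHECK(scclNetIb.listen(0, handle, &lComm));
//   /* exchange `handle` with the peer out of band, then poll the
//      non-blocking connect/accept until both sides are established: */
//   while(sComm == NULL) SCCLCHECK(scclNetIb.connect(0, handle, &sComm));
//   while(rComm == NULL) SCCLCHECK(scclNetIb.accept(lComm, &rComm));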
} // namespace device
} // namespace net
} // namespace hardware
} // namespace sccl
#pragma once
#include <assert.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <poll.h>
#include <sys/types.h>
#include <unistd.h>
#include "ibvwrap.h"
#include "net_utils.h"
namespace sccl {
namespace hardware {
namespace net {
namespace device {
//////////////////////////////////
extern scclNet_t scclNetIb;
} // namespace device
} // namespace net
} // namespace hardware
} // namespace sccl
#include <pthread.h>
#include <stdlib.h>
#include <poll.h>
#include <limits.h>
#include <fcntl.h>
#include "socket.h"
#include "net_socket.h"
namespace sccl {
namespace hardware {
namespace net {
namespace host {
namespace net_socket {
#define MAX_LINE_LEN (2047)
/* Init functions */
static int scclNetIfs = -1;
struct scclNetSocketDev {
union scclSocketAddress addr;
char devName[MAX_IF_NAME_SIZE];
char* pciPath;
};
static struct scclNetSocketDev scclNetSocketDevs[MAX_IFS];
pthread_mutex_t scclNetSocketLock = PTHREAD_MUTEX_INITIALIZER;
static scclResult_t scclNetSocketGetPciPath(char* devName, char** pciPath) {
char devicePath[PATH_MAX];
snprintf(devicePath, PATH_MAX, "/sys/class/net/%s/device", devName);
// May return NULL if the file doesn't exist.
*pciPath = realpath(devicePath, NULL);
return scclSuccess;
}
scclResult_t scclNetSocketInit(void) {
if(scclNetIfs == -1) {
pthread_mutex_lock(&scclNetSocketLock);
if(scclNetIfs == -1) {
char names[MAX_IF_NAME_SIZE * MAX_IFS];
union scclSocketAddress addrs[MAX_IFS];
scclNetIfs = scclFindSocketInterfaces(names, addrs, MAX_IF_NAME_SIZE, MAX_IFS);
if(scclNetIfs <= 0) {
WARN("NET/Socket : no interface found");
return scclInternalError;
} else {
char line[MAX_LINE_LEN + 1];
char addrline[SOCKET_NAME_MAXLEN + 1];
line[0] = '\0';
addrline[SOCKET_NAME_MAXLEN] = '\0';
for(int i = 0; i < scclNetIfs; i++) {
strcpy(scclNetSocketDevs[i].devName, names + i * MAX_IF_NAME_SIZE);
memcpy(&scclNetSocketDevs[i].addr, addrs + i, sizeof(union scclSocketAddress));
SCCLCHECK(scclNetSocketGetPciPath(scclNetSocketDevs[i].devName, &scclNetSocketDevs[i].pciPath));
snprintf(line + strlen(line),
MAX_LINE_LEN - strlen(line),
" [%d]%s:%s",
i,
names + i * MAX_IF_NAME_SIZE,
scclSocketToString(&addrs[i], addrline));
}
line[MAX_LINE_LEN] = '\0';
INFO(SCCL_LOG_NET, "NET/Socket : Using%s", line);
}
}
pthread_mutex_unlock(&scclNetSocketLock);
}
return scclSuccess;
}
scclResult_t scclNetSocketDevices(int* ndev) {
*ndev = scclNetIfs;
return scclSuccess;
}
static scclResult_t scclNetSocketGetSpeed(char* devName, int* speed) {
*speed = 0;
char speedPath[PATH_MAX];
sprintf(speedPath, "/sys/class/net/%s/speed", devName);
int fd = open(speedPath, O_RDONLY);
if(fd != -1) {
char speedStr[] = "        "; // 9-byte scratch buffer (8 spaces + NUL) for the sysfs speed value
if(read(fd, speedStr, sizeof(speedStr) - 1) > 0) {
*speed = strtol(speedStr, NULL, 0);
}
close(fd);
}
if(*speed <= 0) {
INFO(SCCL_LOG_NET, "Could not get speed from %s. Defaulting to 10 Gbps.", speedPath);
*speed = 10000;
}
return scclSuccess;
}
scclResult_t scclNetSocketGetProperties(int dev, scclNetProperties_t* props) {
props->name = scclNetSocketDevs[dev].devName;
props->pciPath = scclNetSocketDevs[dev].pciPath;
props->guid = dev;
props->ptrSupport = SCCL_PTR_HOST;
SCCLCHECK(scclNetSocketGetSpeed(props->name, &props->speed));
props->latency = 0; // Not set
props->port = 0;
props->maxComms = 65536;
props->maxRecvs = 1;
return scclSuccess;
}
/* Communication functions */
#define MAX_SOCKETS 64
#define MAX_THREADS 16
#define MAX_REQUESTS SCCL_NET_MAX_REQUESTS
#define MIN_CHUNKSIZE (64 * 1024)
SCCL_PARAM(SocketNsocksPerThread, "NSOCKS_PERTHREAD", -2);
SCCL_PARAM(SocketNthreads, "SOCKET_NTHREADS", -2);
enum scclNetSocketCommState : uint8_t {
scclNetSocketCommStateStart = 0,
scclNetSocketCommStateConnect = 1,
scclNetSocketCommStateAccept = 3,
scclNetSocketCommStateSend = 4,
scclNetSocketCommStateRecv = 5,
};
struct scclNetSocketCommStage {
enum scclNetSocketCommState state;
uint8_t iteration;
struct scclSocket* sock;
struct scclNetSocketComm* comm;
};
struct scclNetSocketHandle {
union scclSocketAddress connectAddr;
uint64_t magic; // random number to help debugging
int nSocks;
int nThreads;
struct scclNetSocketCommStage stage;
};
struct scclNetSocketTask {
int op;
void* data;
int size;
struct scclSocket* sock;
int offset;
int used;
scclResult_t result;
};
struct scclNetSocketRequest {
int op;
void* data;
int size;
struct scclSocket* ctrlSock;
int offset;
int used;
struct scclNetSocketComm* comm;
struct scclNetSocketTask* tasks[MAX_SOCKETS];
int nSubs;
};
struct scclNetSocketTaskQueue {
int next;
int len;
struct scclNetSocketTask* tasks;
};
struct scclNetSocketThreadResources {
struct scclNetSocketTaskQueue threadTaskQueue;
int stop;
struct scclNetSocketComm* comm;
pthread_mutex_t threadLock;
pthread_cond_t threadCond;
};
struct scclNetSocketListenComm {
struct scclSocket sock;
struct scclNetSocketCommStage stage;
int nSocks;
int nThreads;
int dev;
};
struct scclNetSocketComm {
struct scclSocket ctrlSock;
struct scclSocket socks[MAX_SOCKETS];
int dev;
int cudaDev;
int nSocks;
int nThreads;
int nextSock;
struct scclNetSocketRequest requests[MAX_REQUESTS];
pthread_t helperThread[MAX_THREADS];
struct scclNetSocketThreadResources threadResources[MAX_THREADS];
};
void* persistentSocketThread(void* args_) {
struct scclNetSocketThreadResources* resource = (struct scclNetSocketThreadResources*)args_;
struct scclNetSocketComm* comm = resource->comm;
struct scclNetSocketTaskQueue* myQueue = &resource->threadTaskQueue;
int nSocksPerThread = comm->nSocks / comm->nThreads;
while(1) {
int idle = 1;
int mark = myQueue->next; // mark newest task seen
for(int i = 0; i < myQueue->len; i += nSocksPerThread) {
int repeat;
do {
repeat = 0;
for(int j = 0; j < nSocksPerThread; j++) {
struct scclNetSocketTask* r = myQueue->tasks + i + j;
if(r != NULL && r->used == 1 && r->offset < r->size) {
r->result = scclSocketProgress(r->op, r->sock, r->data, r->size, &r->offset);
if(r->result != scclSuccess) {
WARN("NET/Socket : socket progress error");
return NULL;
}
idle = 0;
if(r->offset < r->size)
repeat = 1;
}
}
} while(repeat);
}
if(idle) {
pthread_mutex_lock(&resource->threadLock);
while(mark == myQueue->next && resource->stop == 0) { // no new tasks, wait
pthread_cond_wait(&resource->threadCond, &resource->threadLock);
}
pthread_mutex_unlock(&resource->threadLock);
}
if(resource->stop)
return NULL;
}
}
scclResult_t scclNetSocketGetNsockNthread(int dev, int* ns, int* nt) {
int nSocksPerThread = scclParamSocketNsocksPerThread();
int nThreads = scclParamSocketNthreads();
if(nThreads > MAX_THREADS) {
WARN("NET/Socket : SCCL_SOCKET_NTHREADS is greater than the maximum allowed, setting to %d", MAX_THREADS);
nThreads = MAX_THREADS;
}
if(nThreads == -2 || nSocksPerThread == -2) {
// Auto-detection
int autoNt = 0, autoNs = 1; // By default, we only use the main thread and do not spawn extra threads
char vendorPath[PATH_MAX];
snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", scclNetSocketDevs[dev].devName);
char* rPath = realpath(vendorPath, NULL);
int fd = (rPath == NULL) ? -1 : open(rPath, O_RDONLY);
free(rPath);
if(fd == -1) {
// Could not read the device vendor; fall back to the defaults below.
// An INFO trace is enough here, this is not an error.
INFO(SCCL_LOG_NET, "Open of %s failed : %s", vendorPath, strerror(errno));
goto end;
}
goto end;
}
char vendor[7];
strncpy(vendor, "0x0000", 7);
int len;
SYSCHECKVAL(read(fd, vendor, 6), "read", len);
SYSCHECK(close(fd), "close");
if(strcmp(vendor, "0x1d0f") == 0) { // AWS
autoNt = 2;
autoNs = 8;
} else if(strcmp(vendor, "0x1ae0") == 0) { // GCP
autoNt = 4;
autoNs = 1;
}
end:
if(nThreads == -2)
nThreads = autoNt;
if(nSocksPerThread == -2)
nSocksPerThread = autoNs;
}
int nSocks = nSocksPerThread * nThreads;
if(nSocks > MAX_SOCKETS) {
nSocksPerThread = MAX_SOCKETS / nThreads;
WARN("NET/Socket : the total number of sockets is greater than the maximum allowed, setting SCCL_NSOCKS_PERTHREAD to %d", nSocksPerThread);
nSocks = nSocksPerThread * nThreads;
}
*ns = nSocks;
*nt = nThreads;
if(nSocks > 0)
INFO(SCCL_LOG_NET, "NET/Socket: Using %d threads and %d sockets per thread", nThreads, nSocksPerThread);
return scclSuccess;
}
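// Worked example (illustrative, not from the source): with SCCL_SOCKET_NTHREADS=2
// and SCCL_NSOCKS_PERTHREAD=4, nSocks = 2 * 4 = 8 data sockets. With both params
// left at -2 on an AWS NIC (vendor 0x1d0f), auto-detection yields nThreads=2 and
// nSocksPerThread=8, i.e. 16 sockets. If the product exceeded MAX_SOCKETS (64),
// nSocksPerThread would be clamped to MAX_SOCKETS / nThreads.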
scclResult_t scclNetSocketListen(int dev, void* opaqueHandle, void** listenComm) {
if(dev < 0 || dev >= scclNetIfs) { // data transfer socket is based on specified dev
return scclInternalError;
}
struct scclNetSocketHandle* handle = (struct scclNetSocketHandle*)opaqueHandle;
memset(handle, 0, sizeof(struct scclNetSocketHandle));
static_assert(sizeof(struct scclNetSocketHandle) <= SCCL_NET_HANDLE_MAXSIZE, "scclNetSocketHandle size too large");
struct scclNetSocketListenComm* comm;
SCCLCHECK(scclCalloc(&comm, 1));
handle->magic = SCCL_SOCKET_MAGIC;
SCCLCHECK(scclSocketInit(&comm->sock, &scclNetSocketDevs[dev].addr, handle->magic, scclSocketTypeNetSocket, NULL, 1));
SCCLCHECK(scclSocketListen(&comm->sock));
SCCLCHECK(scclSocketGetAddr(&comm->sock, &handle->connectAddr));
SCCLCHECK(scclNetSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads));
handle->nSocks = comm->nSocks;
handle->nThreads = comm->nThreads;
comm->dev = dev;
*listenComm = comm;
return scclSuccess;
}
scclResult_t scclNetSocketConnect(int dev, void* opaqueHandle, void** sendComm) {
if(dev < 0 || dev >= scclNetIfs) { // data transfer socket is based on specified dev
return scclInternalError;
}
int ready;
struct scclNetSocketHandle* handle = (struct scclNetSocketHandle*)opaqueHandle;
struct scclNetSocketCommStage* stage = &handle->stage;
struct scclNetSocketComm* comm = stage->comm;
uint8_t i = stage->iteration;
struct scclSocket* sock = stage->sock;
*sendComm = NULL;
if(stage->state == scclNetSocketCommStateConnect)
goto socket_connect_check;
if(stage->state == scclNetSocketCommStateSend)
goto socket_send;
SCCLCHECK(scclCalloc(&comm, 1));
stage->comm = comm;
comm->nSocks = handle->nSocks;
comm->nThreads = handle->nThreads;
comm->dev = dev;
HIPCHECK(hipGetDevice(&comm->cudaDev));
for(; i < comm->nSocks + 1; i++) {
sock = (i == comm->nSocks) ? &comm->ctrlSock : comm->socks + i;
SCCLCHECK(scclSocketInit(sock, &handle->connectAddr, handle->magic, scclSocketTypeNetSocket, NULL, 1));
stage->sock = sock;
stage->state = scclNetSocketCommStateConnect;
stage->iteration = i;
SCCLCHECK(scclSocketConnect(sock));
socket_connect_check:
SCCLCHECK(scclSocketReady(sock, &ready));
if(!ready)
return scclSuccess;
stage->state = scclNetSocketCommStateSend;
socket_send:
int done = 0;
SCCLCHECK(scclSocketProgress(SCCL_SOCKET_SEND, sock, &i, sizeof(uint8_t), &done));
if(done == 0)
return scclSuccess;
}
*sendComm = comm;
return scclSuccess;
}
scclResult_t scclNetSocketAccept(void* listenComm, void** recvComm) {
struct scclNetSocketListenComm* lComm = (struct scclNetSocketListenComm*)listenComm;
struct scclNetSocketCommStage* stage = &lComm->stage;
struct scclNetSocketComm* rComm = stage->comm;
uint8_t i = stage->iteration;
struct scclSocket* sock = stage->sock;
int ready;
*recvComm = NULL;
if(stage->state == scclNetSocketCommStateAccept)
goto socket_accept_check;
if(stage->state == scclNetSocketCommStateRecv)
goto socket_recv;
SCCLCHECK(scclCalloc(&rComm, 1));
stage->comm = rComm;
rComm->nSocks = lComm->nSocks;
rComm->nThreads = lComm->nThreads;
rComm->dev = lComm->dev;
HIPCHECK(hipGetDevice(&rComm->cudaDev));
for(; i < rComm->nSocks + 1; i++) {
uint8_t sendSockIdx;
SCCLCHECK(scclCalloc(&sock, 1));
SCCLCHECK(scclSocketInit(sock));
stage->sock = sock;
stage->state = scclNetSocketCommStateAccept;
stage->iteration = i;
SCCLCHECK(scclSocketAccept(sock, &lComm->sock));
socket_accept_check:
SCCLCHECK(scclSocketReady(sock, &ready));
if(!ready)
return scclSuccess;
stage->state = scclNetSocketCommStateRecv;
socket_recv:
int done = 0;
SCCLCHECK(scclSocketProgress(SCCL_SOCKET_RECV, sock, &sendSockIdx, sizeof(uint8_t), &done));
if(done == 0)
return scclSuccess;
if(sendSockIdx == rComm->nSocks)
memcpy(&rComm->ctrlSock, sock, sizeof(struct scclSocket));
else
memcpy(rComm->socks + sendSockIdx, sock, sizeof(struct scclSocket));
free(sock);
}
*recvComm = rComm;
/* reset lComm state */
stage->state = scclNetSocketCommStateStart;
stage->iteration = 0;
stage->sock = NULL;
stage->comm = NULL;
return scclSuccess;
}
scclResult_t scclNetSocketGetRequest(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketRequest** req) {
for(int i = 0; i < MAX_REQUESTS; i++) {
struct scclNetSocketRequest* r = comm->requests + i;
if(r->used == 0) {
r->op = op;
r->data = data;
r->size = size;
r->ctrlSock = &comm->ctrlSock;
r->used = 1;
r->comm = comm;
r->nSubs = 0;
*req = r;
return scclSuccess;
}
}
WARN("NET/Socket : unable to allocate requests");
return scclInternalError;
}
scclResult_t scclNetSocketGetTask(struct scclNetSocketComm* comm, int op, void* data, int size, struct scclNetSocketTask** req) {
int tid = comm->nextSock % comm->nThreads;
struct scclNetSocketThreadResources* res = comm->threadResources + tid;
struct scclNetSocketTaskQueue* queue = &res->threadTaskQueue;
// create helper threads and prepare per-thread task queue
if(queue->tasks == NULL) {
// each request can be divided up to nSocks tasks, and
// these tasks are distributed to nThreads threads,
// we need to make sure each thread queue has enough slots for MAX_REQUESTS
queue->len = MAX_REQUESTS * DIVUP(comm->nSocks, comm->nThreads);
SCCLCHECK(scclCalloc(&queue->tasks, queue->len));
queue->next = 0;
res->comm = comm;
pthread_mutex_init(&res->threadLock, NULL);
pthread_cond_init(&res->threadCond, NULL);
pthread_create(comm->helperThread + tid, NULL, persistentSocketThread, res);
scclSetThreadName(comm->helperThread[tid], "SCCL Sock%c%1u%2u%2u", op == SCCL_SOCKET_SEND ? 'S' : 'R', comm->dev, tid, comm->cudaDev);
}
struct scclNetSocketTask* r = queue->tasks + queue->next;
if(r->used == 0) {
r->op = op;
r->data = data;
r->size = size;
r->sock = comm->socks + comm->nextSock;
r->offset = 0;
r->result = scclSuccess;
comm->nextSock = (comm->nextSock + 1) % comm->nSocks;
r->used = 1;
*req = r;
pthread_mutex_lock(&res->threadLock);
queue->next = (queue->next + 1) % queue->len;
pthread_cond_signal(&res->threadCond);
pthread_mutex_unlock(&res->threadLock);
return scclSuccess;
}
WARN("NET/Socket : unable to allocate subtasks");
return scclInternalError;
}
scclResult_t scclNetSocketTest(void* request, int* done, int* size) {
*done = 0;
struct scclNetSocketRequest* r = (struct scclNetSocketRequest*)request;
if(r == NULL) {
WARN("NET/Socket : test called with NULL request");
return scclInternalError;
}
if(r->used == 1) { /* try to send/recv size */
int data = r->size;
int offset = 0;
SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, &data, sizeof(int), &offset));
if(offset == 0)
return scclSuccess; /* Not ready -- retry later */
// Not sure we could ever receive less than 4 bytes, but just in case ...
if(offset < sizeof(int))
SCCLCHECK(scclSocketWait(r->op, r->ctrlSock, &data, sizeof(int), &offset));
// Check size is less or equal to the size provided by the user
if(r->op == SCCL_SOCKET_RECV && data > r->size) {
char line[SOCKET_NAME_MAXLEN + 1];
union scclSocketAddress addr;
scclSocketGetAddr(r->ctrlSock, &addr);
WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d. If you believe your socket network is in healthy state, \
there may be a mismatch in collective sizes or environment settings (e.g. SCCL_PROTO, SCCL_ALGO) between ranks",
scclSocketToString(&addr, line),
data,
r->size);
return scclInvalidUsage;
}
r->size = data;
r->offset = 0;
r->used = 2; // done exchanging size
// divide into subtasks
int chunkOffset = 0, i = 0;
if(r->comm->nSocks > 0) {
// each request can be divided up to nSocks tasks
int taskSize = std::max(MIN_CHUNKSIZE, DIVUP(r->size, r->comm->nSocks));
while(chunkOffset < r->size) {
int chunkSize = std::min(taskSize, r->size - chunkOffset);
SCCLCHECK(scclNetSocketGetTask(r->comm, r->op, (char*)(r->data) + chunkOffset, chunkSize, r->tasks + i++));
chunkOffset += chunkSize;
}
}
r->nSubs = i;
}
if(r->used == 2) { // already exchanged size
if(r->nSubs > 0) {
int nCompleted = 0;
for(int i = 0; i < r->nSubs; i++) {
struct scclNetSocketTask* sub = r->tasks[i];
if(sub->result != scclSuccess)
return sub->result;
if(sub->offset == sub->size)
nCompleted++;
}
if(nCompleted == r->nSubs) {
if(size)
*size = r->size;
*done = 1;
r->used = 0;
for(int i = 0; i < r->nSubs; i++) {
struct scclNetSocketTask* sub = r->tasks[i];
sub->used = 0;
}
}
} else { // progress request using main thread
if(r->offset < r->size) {
SCCLCHECK(scclSocketProgress(r->op, r->ctrlSock, r->data, r->size, &r->offset));
}
if(r->offset == r->size) {
if(size)
*size = r->size;
*done = 1;
r->used = 0;
}
}
}
return scclSuccess;
}
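// Worked example of the subtask split above (illustrative): a 1 MiB receive on a
// comm with nSocks = 8 gives taskSize = max(MIN_CHUNKSIZE, DIVUP(1 MiB, 8)) =
// max(64 KiB, 128 KiB) = 128 KiB, so the request becomes 8 subtasks of 128 KiB,
// round-robined over the helper threads by scclNetSocketGetTask. A 100 KiB
// message yields two subtasks (64 KiB + 36 KiB), since chunks never shrink
// below MIN_CHUNKSIZE.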
scclResult_t scclNetSocketRegMr(void* comm, void* data, int size, int type, void** mhandle) {
return (type != SCCL_PTR_HOST) ? scclInternalError : scclSuccess;
}
scclResult_t scclNetSocketDeregMr(void* comm, void* mhandle) { return scclSuccess; }
scclResult_t scclNetSocketIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) {
struct scclNetSocketComm* comm = (struct scclNetSocketComm*)sendComm;
SCCLCHECK(scclNetSocketGetRequest(comm, SCCL_SOCKET_SEND, data, size, (struct scclNetSocketRequest**)request));
return scclSuccess;
}
scclResult_t scclNetSocketIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) {
struct scclNetSocketComm* comm = (struct scclNetSocketComm*)recvComm;
if(n != 1)
return scclInternalError;
SCCLCHECK(scclNetSocketGetRequest(comm, SCCL_SOCKET_RECV, data[0], sizes[0], (struct scclNetSocketRequest**)request));
return scclSuccess;
}
scclResult_t scclNetSocketIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) {
// We don't support HIP pointers, so we don't need a flush operation
return scclInternalError;
}
scclResult_t scclNetSocketCloseListen(void* opaqueComm) {
struct scclNetSocketListenComm* comm = (struct scclNetSocketListenComm*)opaqueComm;
if(comm) {
int ready;
SCCLCHECK(scclSocketReady(&comm->sock, &ready));
if(ready)
SCCLCHECK(scclSocketClose(&comm->sock));
free(comm);
}
return scclSuccess;
}
scclResult_t scclNetSocketClose(void* opaqueComm) {
struct scclNetSocketComm* comm = (struct scclNetSocketComm*)opaqueComm;
if(comm) {
for(int i = 0; i < comm->nThreads; i++) {
struct scclNetSocketThreadResources* res = comm->threadResources + i;
if(comm->helperThread[i]) {
pthread_mutex_lock(&res->threadLock);
res->stop = 1;
pthread_cond_signal(&res->threadCond);
pthread_mutex_unlock(&res->threadLock);
pthread_join(comm->helperThread[i], NULL);
}
free(res->threadTaskQueue.tasks);
}
int ready;
SCCLCHECK(scclSocketReady(&comm->ctrlSock, &ready));
if(ready)
SCCLCHECK(scclSocketClose(&comm->ctrlSock));
for(int i = 0; i < comm->nSocks; i++) {
SCCLCHECK(scclSocketReady(&comm->socks[i], &ready));
if(ready)
SCCLCHECK(scclSocketClose(&comm->socks[i]));
}
free(comm);
}
return scclSuccess;
}
} // namespace net_socket
scclNet_t scclNetSocket = {"Socket",
net_socket::scclNetSocketInit,
net_socket::scclNetSocketDevices,
net_socket::scclNetSocketGetProperties,
net_socket::scclNetSocketListen,
net_socket::scclNetSocketConnect,
net_socket::scclNetSocketAccept,
net_socket::scclNetSocketRegMr,
NULL, // No DMA-BUF support
net_socket::scclNetSocketDeregMr,
net_socket::scclNetSocketIsend,
net_socket::scclNetSocketIrecv,
net_socket::scclNetSocketIflush,
net_socket::scclNetSocketTest,
net_socket::scclNetSocketClose,
net_socket::scclNetSocketClose,
net_socket::scclNetSocketCloseListen};
} // namespace host
} // namespace net
} // namespace hardware
} // namespace sccl
#pragma once
#include <assert.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <poll.h>
#include <sys/types.h>
#include <unistd.h>
#include "base.h"
#include "net_utils.h"
namespace sccl {
namespace hardware {
namespace net {
namespace host {
//////////////////////////////////
extern scclNet_t scclNetSocket;
} // namespace host
} // namespace net
} // namespace hardware
} // namespace sccl
#include <string.h>
#include <netdb.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <vector>
#include <utility>
#include <unordered_set>
#include <unistd.h>
#include <sys/syscall.h>
#include "socket.h"
#include "utils/param.h"
#include "net_utils.h"
namespace sccl {
namespace hardware {
namespace net {
namespace host {
namespace socket_base {
/**
* Allow the user to force the IPv4/IPv6 interface selection via the
* SCCL_SOCKET_FAMILY environment variable.
*
* @return AF_INET (IPv4) or AF_INET6 (IPv6); -1 if the variable is not set
* @note The comparison is case-sensitive: only the exact values "AF_INET"
* and "AF_INET6" are recognized
*/
static int scclGetSocketFamily(void) {
int family = -1; // Family selection is not forced, will use first one found
const char* env = scclGetEnv("SCCL_SOCKET_FAMILY");
if(env == NULL)
return family;
INFO(SCCL_LOG_NET, "SCCL_SOCKET_FAMILY set by environment to %s", env);
if(strcmp(env, "AF_INET") == 0)
family = AF_INET; // IPv4
else if(strcmp(env, "AF_INET6") == 0)
family = AF_INET6; // IPv6
return family;
}
/**
* @brief Find all network interfaces matching the given criteria.
*
* Enumerates the system's network interfaces and keeps those matching the
* prefix list and, optionally, a forced socket address family.
*
* @param prefixList interface-name prefix list ('^' negates the list, '=' requires an exact match)
* @param names output buffer for the matched interface names (maxIfNameSize bytes each)
* @param addrs output buffer for the matched interface addresses (array of scclSocketAddress unions)
* @param sock_family forced address family (AF_INET/AF_INET6), or -1 for no restriction
* @param maxIfNameSize maximum length of a single interface name
* @param maxIfs maximum number of interfaces to return
* @return int the number of interfaces found
*
* @note IPv6 loopback addresses are skipped, and duplicate entries for the
* same interface (its IPv4 and IPv6 addresses) are filtered out.
*/
static int findSocketInterfaces(const char* prefixList, char* names, union scclSocketAddress* addrs, int sock_family, int maxIfNameSize, int maxIfs) {
char line[SOCKET_NAME_MAXLEN + 1];
struct netIf userIfs[MAX_IFS];
bool searchNot = prefixList && prefixList[0] == '^';
if(searchNot)
prefixList++;
bool searchExact = prefixList && prefixList[0] == '=';
if(searchExact)
prefixList++;
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
////////////////////////////////////////////////////////////////////////////
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
for(interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) {
if(interface->ifa_addr == NULL)
continue;
/* We only support IPv4 & IPv6 */
int family = interface->ifa_addr->sa_family;
if(family != AF_INET && family != AF_INET6)
continue;
INFO(SCCL_LOG_NET, "Found interface %s:%s", interface->ifa_name, scclSocketToString((union scclSocketAddress*)interface->ifa_addr, line));
/* Allow the caller to force the socket family type */
if(sock_family != -1 && family != sock_family)
continue;
/* We also need to skip IPv6 loopback interfaces */
if(family == AF_INET6) {
struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr);
if(IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr))
continue;
}
// check against user specified interfaces
if(!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
continue;
}
// Check that this interface has not already been saved
// getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link
bool duplicate = false;
for(int i = 0; i < found; i++) {
if(strcmp(interface->ifa_name, names + i * maxIfNameSize) == 0) {
duplicate = true;
break;
}
}
if(!duplicate) {
// Store the interface name
strncpy(names + found * maxIfNameSize, interface->ifa_name, maxIfNameSize);
// Store the IP address
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
memset(addrs + found, '\0', sizeof(*addrs));
memcpy(addrs + found, interface->ifa_addr, salen);
found++;
}
}
freeifaddrs(interfaces);
return found;
}
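// Illustrative prefix lists for findSocketInterfaces (hypothetical values):
//   "eth,ib"      match interfaces whose names start with eth or ib
//   "=eth0"       match eth0 exactly
//   "^docker,lo"  match everything except names starting with docker or lo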
/**
* Check whether a local interface and a remote address are on the same subnet.
*
* @param local_if local network interface information
* @param remote remote socket address (union, IPv4/IPv6)
* @return bool true if both are on the same subnet, false otherwise
*
* @note Supports IPv4 and IPv6; for IPv6 the scope_id is compared as well
* @warning An unsupported address family logs a warning and returns false
*/
static bool matchSubnet(struct ifaddrs local_if, union scclSocketAddress* remote) {
/* Check family first */
int family = local_if.ifa_addr->sa_family;
if(family != remote->sa.sa_family) {
return false;
}
if(family == AF_INET) {
struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr);
struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask);
struct sockaddr_in& remote_addr = remote->sin;
struct in_addr local_subnet, remote_subnet;
local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr;
return (local_subnet.s_addr ^ remote_subnet.s_addr) ? false : true;
} else if(family == AF_INET6) {
struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr);
struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask);
struct sockaddr_in6& remote_addr = remote->sin6;
struct in6_addr& local_in6 = local_addr->sin6_addr;
struct in6_addr& mask_in6 = mask->sin6_addr;
struct in6_addr& remote_in6 = remote_addr.sin6_addr;
bool same = true;
int len = 16; // IPv6 address is 16 unsigned char
for(int c = 0; c < len; c++) { // Network byte order is big-endian
char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c];
char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c];
if(c1 ^ c2) {
same = false;
break;
}
}
// At last, we need to compare scope id
// Two Link-type addresses can have the same subnet address even though they are not in the same scope
// For Global type, this field is 0, so a comparison wouldn't matter
same &= (local_addr->sin6_scope_id == remote_addr.sin6_scope_id);
return same;
} else {
WARN("Net : Unsupported address family type");
return false;
}
}
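// Worked example (illustrative): a local interface 192.168.1.10 with netmask
// 255.255.255.0 and a remote peer 192.168.1.77 both mask to subnet 192.168.1.0,
// so matchSubnet returns true; a remote 192.168.2.5 masks to 192.168.2.0 and
// returns false.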
/**
* Extract the port number from a socket address.
*
* @param addr pointer to the scclSocketAddress union holding the address
* @return the port number converted from network to host byte order
*
* @note Supports IPv4 and IPv6 addresses
*/
static uint16_t socketToPort(union scclSocketAddress* addr) {
struct sockaddr* saddr = &addr->sa;
return ntohs(saddr->sa_family == AF_INET ? addr->sin.sin_port : addr->sin6.sin6_port);
}
////////////////////////////////////////////////////////////////////
static std::vector<std::pair<int, std::unordered_set<std::string>>> clientPortPool;
static scclResult_t socketProgressOpt(int op, struct scclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) {
int bytes = 0;
*closed = 0;
char* data = (char*)ptr;
char line[SOCKET_NAME_MAXLEN + 1];
do {
if(op == SCCL_SOCKET_RECV)
bytes = recv(sock->fd, data + (*offset), size - (*offset), block ? 0 : MSG_DONTWAIT);
if(op == SCCL_SOCKET_SEND)
bytes = send(sock->fd, data + (*offset), size - (*offset), block ? MSG_NOSIGNAL : MSG_DONTWAIT | MSG_NOSIGNAL);
if(op == SCCL_SOCKET_RECV && bytes == 0) {
*closed = 1;
return scclSuccess;
}
if(bytes == -1) {
if(errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
WARN("socketProgressOpt: Call to recv from %s failed : %s", scclSocketToString(&sock->addr, line), strerror(errno));
return scclRemoteError;
} else {
bytes = 0;
}
}
(*offset) += bytes;
if(sock->abortFlag && *sock->abortFlag != 0) {
INFO(SCCL_LOG_NET, "socketProgressOpt: abort called");
return scclInternalError;
}
} while(bytes > 0 && (*offset) < size);
return scclSuccess;
}
static scclResult_t socketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
int closed;
SCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0 /*block*/, &closed));
if(closed) {
char line[SOCKET_NAME_MAXLEN + 1];
WARN("socketProgress: Connection closed by remote peer %s", scclSocketToString(&sock->addr, line, 0));
return scclRemoteError;
}
return scclSuccess;
}
static scclResult_t socketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
while(*offset < size)
SCCLCHECK(socketProgress(op, sock, ptr, size, offset));
return scclSuccess;
}
static scclResult_t socketTryAccept(struct scclSocket* sock) {
socklen_t socklen = sizeof(union scclSocketAddress);
sock->fd = accept(sock->acceptFd, &sock->addr.sa, &socklen);
if(sock->fd != -1) {
sock->state = scclSocketStateAccepted;
} else if(errno != EAGAIN && errno != EWOULDBLOCK) {
WARN("socketTryAccept: Accept failed: %s", strerror(errno));
return scclSystemError;
}
return scclSuccess;
}
static scclResult_t socketFinalizeAccept(struct scclSocket* sock) {
uint64_t magic;
enum scclSocketType type;
int received = 0;
const int one = 1;
SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");
SCCLCHECK(scclSocketProgress(SCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
if(received == 0)
return scclSuccess;
SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received));
if(magic != sock->magic) {
WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic);
close(sock->fd);
sock->fd = -1;
// Ignore spurious connection and accept again
sock->state = scclSocketStateAccepting;
return scclSuccess;
} else {
received = 0;
SCCLCHECK(socketWait(SCCL_SOCKET_RECV, sock, &type, sizeof(type), &received));
if(type != sock->type) {
WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type);
sock->state = scclSocketStateError;
close(sock->fd);
sock->fd = -1;
return scclInternalError;
} else {
sock->state = scclSocketStateReady;
}
}
return scclSuccess;
}
static scclResult_t socketStartConnect(struct scclSocket* sock) {
/* blocking/non-blocking connect() is determined by asyncFlag. */
int ret = connect(sock->fd, &sock->addr.sa, sock->salen);
if(ret == 0) {
sock->state = scclSocketStateConnected;
return scclSuccess;
} else if(errno == EINPROGRESS) {
sock->state = scclSocketStateConnectPolling;
return scclSuccess;
} else if(errno == ECONNREFUSED) {
if(++sock->refusedRetries == RETRY_REFUSED_TIMES) {
sock->state = scclSocketStateError;
WARN("socketStartConnect: exceeded retries (%d)", sock->refusedRetries);
return scclRemoteError;
}
usleep(SLEEP_INT);
if(sock->refusedRetries % 1000 == 0)
INFO(SCCL_LOG_CODEALL, "Call to connect returned %s, retrying", strerror(errno));
return scclSuccess;
} else if(errno == ETIMEDOUT) {
if(++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) {
sock->state = scclSocketStateError;
WARN("socketStartConnect: exceeded timeouts (%d)", sock->timedOutRetries);
return scclRemoteError;
}
usleep(SLEEP_INT);
return scclSuccess;
} else {
char line[SOCKET_NAME_MAXLEN + 1];
sock->state = scclSocketStateError;
WARN("socketStartConnect: Connect to %s failed : %s", scclSocketToString(&sock->addr, line), strerror(errno));
return scclSystemError;
}
}
static scclResult_t socketPollConnect(struct scclSocket* sock) {
struct pollfd pfd;
int timeout = 1, ret;
socklen_t rlen = sizeof(int);
memset(&pfd, 0, sizeof(struct pollfd));
pfd.fd = sock->fd;
pfd.events = POLLOUT;
ret = poll(&pfd, 1, timeout);
if(ret == 0 || (ret < 0 && errno == EINTR)) {
return scclSuccess;
} else if(ret < 0) {
WARN("socketPollConnect poll() failed with error %s", strerror(errno));
return scclRemoteError;
} else {
EQCHECK(ret == 1 && (pfd.revents & POLLOUT), 0);
}
/* check socket status */
SYSCHECK(getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt");
if(ret == 0) {
sock->state = scclSocketStateConnected;
} else if(ret == ECONNREFUSED) {
if(++sock->refusedRetries == RETRY_REFUSED_TIMES) {
sock->state = scclSocketStateError;
WARN("socketPollConnect: exceeded retries (%d)", sock->refusedRetries);
return scclRemoteError;
}
if(sock->refusedRetries % 1000 == 0)
INFO(SCCL_LOG_CODEALL, "Call to connect returned %s, retrying", strerror(errno));
usleep(SLEEP_INT);
sock->state = scclSocketStateConnecting;
} else if(ret == ETIMEDOUT) {
if(++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) {
sock->state = scclSocketStateError;
WARN("socketPollConnect: exceeded timeouts (%d)", sock->timedOutRetries);
return scclRemoteError;
}
usleep(SLEEP_INT);
sock->state = scclSocketStateConnecting;
} else if(ret != EINPROGRESS) {
sock->state = scclSocketStateError;
return scclSystemError;
}
return scclSuccess;
}
static scclResult_t socketFinalizeConnect(struct scclSocket* sock) {
int sent = 0;
SCCLCHECK(socketProgress(SCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));
if(sent == 0)
return scclSuccess;
SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent));
sent = 0;
SCCLCHECK(socketWait(SCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent));
sock->state = scclSocketStateReady;
return scclSuccess;
}
static scclResult_t socketProgressState(struct host::scclSocket* sock) {
if(sock->state == scclSocketStateAccepting) {
SCCLCHECK(socketTryAccept(sock));
}
if(sock->state == scclSocketStateAccepted) {
SCCLCHECK(socketFinalizeAccept(sock));
}
if(sock->state == scclSocketStateConnecting) {
SCCLCHECK(socketStartConnect(sock));
}
if(sock->state == scclSocketStateConnectPolling) {
SCCLCHECK(socketPollConnect(sock));
}
if(sock->state == scclSocketStateConnected) {
SCCLCHECK(socketFinalizeConnect(sock));
}
return scclSuccess;
}
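// State-machine summary (derived from the handlers above):
//   accept side : Accepting -> Accepted -> Ready; a wrong magic closes the fd
//                 and falls back to Accepting, a wrong type is a hard error.
//   connect side: Connecting -> ConnectPolling -> Connected -> Ready; magic
//                 and type are sent once connect() completes.
// socketProgressState() advances a socket one step through whichever state it
// is currently in; callers poll it via scclSocketReady().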
} // namespace socket_base
/**
* Convert a scclSocketAddress to a human-readable string.
*
* @param addr pointer to the scclSocketAddress union to convert
* @param buf buffer receiving the result
* @param numericHostForm 1 (default) for the numeric host form, 0 to try to resolve the hostname
* @return a pointer to the resulting string on success, NULL on failure
*
* @note Supports IPv4 and IPv6; other address families yield an empty string.
* The format is "host<port>", where host may be an IP address or a hostname.
*/
const char* scclSocketToString(const union scclSocketAddress* addr, char* buf, const int numericHostForm /*= 1*/) {
if(buf == NULL || addr == NULL)
return NULL;
const struct sockaddr* saddr = &addr->sa;
if(saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) {
buf[0] = '\0';
return buf;
}
char host[NI_MAXHOST], service[NI_MAXSERV];
/* NI_NUMERICHOST: If set, then the numeric form of the hostname is returned.
* (When not set, this will still happen in case the node's name cannot be determined.)
*/
int flag = NI_NUMERICSERV | (numericHostForm ? NI_NUMERICHOST : 0);
(void)getnameinfo(saddr, sizeof(union scclSocketAddress), host, NI_MAXHOST, service, NI_MAXSERV, flag);
sprintf(buf, "%s<%s>", host, service);
return buf;
}
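// Example output (illustrative): an IPv4 address 10.0.0.3 on port 4242 formats
// as "10.0.0.3<4242>"; with numericHostForm == 0 the hostname is resolved when
// possible, e.g. "node3<4242>".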
/**
* @brief Parse a socket address from a string.
*
* Parses an IP:PORT string into a socket address structure; IPv4 and IPv6
* forms are both supported.
* IPv4 example: "192.168.1.1:8080"
* IPv6 example: "[fe80::1%eth0]:8080"
*
* @param ua output parameter receiving the parsed socket address
* @param ip_port_pair input string holding the IP address and port
* @return scclResult_t scclSuccess on success, an error code on failure
*/
scclResult_t scclSocketGetAddrFromString(union scclSocketAddress* ua, const char* ip_port_pair) {
if(!(ip_port_pair && strlen(ip_port_pair) > 1)) {
WARN("Net : string is null");
return scclInvalidArgument;
}
bool ipv6 = ip_port_pair[0] == '[';
/* Construct the sockaddress structure */
if(!ipv6) {
struct netIf ni;
// parse <ip_or_hostname>:<port> string, expect one pair
if(parseStringList(ip_port_pair, &ni, 1) != 1) {
WARN("Net : No valid <IPv4_or_hostname>:<port> pair found");
return scclInvalidArgument;
}
struct addrinfo hints, *p;
int rv;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
if((rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) {
WARN("Net : error encountered when getting address info : %s", gai_strerror(rv));
return scclInvalidArgument;
}
// use the first
if(p->ai_family == AF_INET) {
struct sockaddr_in& sin = ua->sin;
memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in));
sin.sin_family = AF_INET; // IPv4
// inet_pton(AF_INET, ni.prefix, &(sin.sin_addr)); // IP address
sin.sin_port = htons(ni.port); // port
} else if(p->ai_family == AF_INET6) {
struct sockaddr_in6& sin6 = ua->sin6;
memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6));
sin6.sin6_family = AF_INET6; // IPv6
sin6.sin6_port = htons(ni.port); // port
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
sin6.sin6_scope_id = 0; // should be global scope, set to 0
} else {
WARN("Net : unsupported IP family");
freeaddrinfo(p);
return scclInvalidArgument;
}
freeaddrinfo(p); // all done with this structure
} else {
int i, j = -1, len = strlen(ip_port_pair);
for(i = 1; i < len; i++) {
if(ip_port_pair[i] == '%')
j = i;
if(ip_port_pair[i] == ']')
break;
}
if(i == len) {
WARN("Net : No valid [IPv6]:port pair found");
return scclInvalidArgument;
}
bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope
char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
memset(ip_str, '\0', sizeof(ip_str));
memset(port_str, '\0', sizeof(port_str));
memset(if_name, '\0', sizeof(if_name));
strncpy(ip_str, ip_port_pair + 1, global_scope ? i - 1 : j - 1);
strncpy(port_str, ip_port_pair + i + 2, len - i - 1);
int port = atoi(port_str);
if(!global_scope)
strncpy(if_name, ip_port_pair + j + 1, i - j - 1); // If not global scope, we need the intf name
struct sockaddr_in6& sin6 = ua->sin6;
sin6.sin6_family = AF_INET6; // IPv6
inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address
sin6.sin6_port = htons(port); // port
sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete
sin6.sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope
}
return scclSuccess;
}
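// Usage sketch (illustrative; the addresses are hypothetical):
//   union scclSocketAddress ua;
//   scclSocketGetAddrFromString(&ua, "192.168.1.1:8080");    // IPv4
//   scclSocketGetAddrFromString(&ua, "[fe80::1%eth0]:8080"); // link-local IPv6
// The IPv6 form derives sin6_scope_id from the "%eth0" suffix via
// if_nametoindex(); without a "%" suffix the scope id is left at 0 (global).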
/**
* Find local network interfaces on the same subnet as a remote address.
*
* @param ifNames output array of matched interface names
* @param localAddrs output array of the matched interfaces' local addresses
* @param remoteAddr remote address used for the subnet match
* @param ifNameMaxSize maximum length of a single interface name
* @param maxIfs maximum number of interfaces to return
* @return the number of matching interfaces found, or 0 if none
*
* @note Only IPv4 and IPv6 addresses are supported
* @warning The caller must size ifNames and localAddrs for maxIfs results
*/
int scclFindInterfaceMatchSubnet(char* ifNames, union scclSocketAddress* localAddrs, union scclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) {
char line[SOCKET_NAME_MAXLEN + 1];
char line_a[SOCKET_NAME_MAXLEN + 1];
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
for(interface = interfaces; interface && !found; interface = interface->ifa_next) {
if(interface->ifa_addr == NULL)
continue;
/* We only support IPv4 & IPv6 */
int family = interface->ifa_addr->sa_family;
if(family != AF_INET && family != AF_INET6)
continue;
// check against user specified interfaces
if(!socket_base::matchSubnet(*interface, remoteAddr)) {
continue;
}
// Store the local IP address
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
memcpy(localAddrs + found, interface->ifa_addr, salen);
// Store the interface name
strncpy(ifNames + found * ifNameMaxSize, interface->ifa_name, ifNameMaxSize);
INFO(SCCL_LOG_NET,
"NET : Found interface %s:%s in the same subnet as remote address %s",
interface->ifa_name,
scclSocketToString(localAddrs + found, line),
scclSocketToString(remoteAddr, line_a));
found++;
if(found == maxIfs)
break;
}
if(found == 0) {
WARN("Net : No interface found in the same subnet as remote address %s", scclSocketToString(remoteAddr, line_a));
}
freeifaddrs(interfaces);
return found;
}
/**
* @brief Find usable socket network interfaces.
*
* Looks for usable network interfaces, either specified by the user through
* an environment variable or auto-detected. Search order: 1) user-specified
* interfaces (SCCL_SOCKET_IFNAME); 2) IB interfaces; 3) interfaces on the
* same subnet as SCCL_COMM_ID. (The fallbacks to other interfaces, docker
* and lo are currently disabled in the code below.)
*
* @param ifNames output parameter storing the found interface names
* @param ifAddrs output parameter storing the found interface addresses
* @param ifNameMaxSize maximum length of a single interface name
* @param maxIfs maximum number of interfaces to return
* @return int the number of interfaces found, 0 if none
*/
int scclFindSocketInterfaces(char* ifNames, union scclSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs) {
static int shownIfName = 0;
int nIfs = 0;
// Allow user to force the INET socket family selection
int sock_family = socket_base::scclGetSocketFamily();
// User specified interface
char* env = getenv("SCCL_SOCKET_IFNAME");
if(env && strlen(env) > 1) {
INFO(SCCL_LOG_NET, "SCCL_SOCKET_IFNAME set by environment to %s", env);
// Specified by user : find or fail
if(shownIfName++ == 0)
INFO(SCCL_LOG_NET, "SCCL_SOCKET_IFNAME set to %s", env);
nIfs = socket_base::findSocketInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
} else {
// Try to automatically pick the right one
// Start with IB
nIfs = socket_base::findSocketInterfaces("ib", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// else see if we can get some hint from COMM ID
if(nIfs == 0) {
char* commId = getenv("SCCL_COMM_ID");
if(commId && strlen(commId) > 1) {
INFO(SCCL_LOG_NET, "SCCL_COMM_ID set by environment to %s", commId);
// Try to find interface that is in the same subnet as the IP in comm id
union scclSocketAddress idAddr;
scclSocketGetAddrFromString(&idAddr, commId);
nIfs = scclFindInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);
}
}
if(nIfs == 0) {
WARN("No socket network interface found. ");
}
// // Then look for anything else (but not docker or lo)
// if(nIfs == 0)
// nIfs = socket_base::findSocketInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// // Finally look for docker, then lo.
// if(nIfs == 0)
// nIfs = socket_base::findSocketInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// if(nIfs == 0)
// nIfs = socket_base::findSocketInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
}
return nIfs;
}
//////////////////////////////////// Basic socket operations ////////////////////////////////////
scclResult_t scclSocketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
if(sock == NULL) {
WARN("scclSocketProgress: pass NULL socket");
return scclInvalidArgument;
}
SCCLCHECK(socket_base::socketProgress(op, sock, ptr, size, offset));
return scclSuccess;
}
scclResult_t scclSocketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset) {
if(sock == NULL) {
WARN("scclSocketWait: pass NULL socket");
return scclInvalidArgument;
}
SCCLCHECK(socket_base::socketWait(op, sock, ptr, size, offset));
return scclSuccess;
}
scclResult_t scclSocketSend(struct scclSocket* sock, void* ptr, int size) {
int offset = 0;
if(sock == NULL) {
WARN("scclSocketSend: pass NULL socket");
return scclInvalidArgument;
}
if(sock->state != scclSocketStateReady) {
WARN("scclSocketSend: socket state (%d) is not ready", sock->state);
return scclInternalError;
}
SCCLCHECK(socket_base::socketWait(SCCL_SOCKET_SEND, sock, ptr, size, &offset));
return scclSuccess;
}
scclResult_t scclSocketRecv(struct scclSocket* sock, void* ptr, int size) {
int offset = 0;
if(sock == NULL) {
WARN("scclSocketRecv: pass NULL socket");
return scclInvalidArgument;
}
if(sock->state != scclSocketStateReady) {
WARN("scclSocketRecv: socket state (%d) is not ready", sock->state);
return scclInternalError;
}
SCCLCHECK(socket_base::socketWait(SCCL_SOCKET_RECV, sock, ptr, size, &offset));
return scclSuccess;
}
// Receive or detect connection closed
scclResult_t scclSocketTryRecv(struct scclSocket* sock, void* ptr, int size, int* closed, bool blocking) {
int offset = 0;
if(sock == NULL) {
WARN("scclSocketTryRecv: pass NULL socket");
return scclInvalidArgument;
}
*closed = 0;
// Block until connection closes or nbytes received
if(blocking) {
while(offset < size) {
SCCLCHECK(socket_base::socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed));
if(*closed)
return scclSuccess;
}
} else {
SCCLCHECK(socket_base::socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed));
if(*closed)
return scclSuccess;
// If any bytes were received, block waiting for the rest
if(offset > 0) {
while(offset < size) {
SCCLCHECK(socket_base::socketProgressOpt(SCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed));
if(*closed)
return scclSuccess;
}
// No bytes were received, return scclInProgress
} else {
return scclInProgress;
}
}
return scclSuccess;
}
scclResult_t scclSocketClose(struct scclSocket* sock) {
if(sock != NULL) {
if(sock->fd >= 0) {
/* shutdown() is needed to send FIN packet to proxy thread; shutdown() is not affected
* by refcount of fd, but close() is. close() won't close a fd and send FIN packet if
* the fd is duplicated (e.g. fork()). So shutdown() guarantees the correct and graceful
* connection close here. */
shutdown(sock->fd, SHUT_RDWR);
close(sock->fd);
}
sock->state = scclSocketStateClosed;
sock->fd = -1;
}
return scclSuccess;
}
//////////////////////////////////// Socket lifecycle ////////////////////////////////////
/**
* Initialize a socket structure and create its file descriptor.
*
* @param sock pointer to the socket structure to initialize
* @param addr target address, may be NULL
* @param magic magic number used to validate the connection
* @param type socket type
* @param abortFlag pointer to a flag used to abort asynchronously
* @param asyncFlag whether the socket operates in asynchronous mode
* @return scclSuccess on success, an error code otherwise
*
* @note If addr is not NULL, a socket fd is created for its address family;
* the connection itself is established later by scclSocketConnect.
* IPv4 and IPv6 are supported. In asynchronous mode, or when an abort
* flag is given, the socket is set to non-blocking.
*/
scclResult_t
scclSocketInit(struct scclSocket* sock, union scclSocketAddress* addr, uint64_t magic, enum scclSocketType type, volatile uint32_t* abortFlag, int asyncFlag) {
scclResult_t ret = scclSuccess;
if(sock == NULL)
goto exit;
sock->timedOutRetries = 0;
sock->refusedRetries = 0;
sock->abortFlag = abortFlag;
sock->asyncFlag = asyncFlag;
sock->state = scclSocketStateInitialized;
sock->magic = magic;
sock->type = type;
sock->fd = -1;
sock->acceptFd = -1;
if(addr) {
/* IPv4/IPv6 support */
int family;
memcpy(&sock->addr, addr, sizeof(union scclSocketAddress));
family = sock->addr.sa.sa_family;
if(family != AF_INET && family != AF_INET6) {
char line[SOCKET_NAME_MAXLEN + 1];
WARN("scclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
scclSocketToString(&sock->addr, line),
family,
AF_INET,
AF_INET6);
ret = scclInternalError;
goto fail;
}
sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
/* Connect to a hostname / port */
sock->fd = socket(family, SOCK_STREAM, 0);
if(sock->fd == -1) {
WARN("scclSocketInit: Socket creation failed : %s", strerror(errno));
ret = scclSystemError;
goto fail;
}
} else {
memset(&sock->addr, 0, sizeof(union scclSocketAddress));
}
/* Set socket as non-blocking if async or if we need to be able to abort */
if((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) {
int flags;
EQCHECKGOTO(flags = fcntl(sock->fd, F_GETFL), -1, ret, fail);
EQCHECKGOTO(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), -1, ret, fail);
}
exit:
return ret;
fail:
goto exit;
}
/**
* Listen for incoming connections on the given socket.
*
* @param sock socket to listen on; must be initialized and not already listening
* @return scclResult_t
* - scclSuccess: listening started
* - scclInvalidArgument: invalid argument (sock is NULL or its fd is invalid)
*
* @note If a port is forced via the environment, SO_REUSEADDR/SO_REUSEPORT is set
* @note The port actually assigned is retrieved with getsockname() and written back to sock->addr
* @note The pending-connection backlog is capped by the system (/proc/sys/net/core/somaxconn)
*/
scclResult_t scclSocketListen(struct scclSocket* sock) {
if(sock == NULL) {
WARN("scclSocketListen: pass NULL socket");
return scclInvalidArgument;
}
if(sock->fd == -1) {
WARN("scclSocketListen: file descriptor is -1");
return scclInvalidArgument;
}
if(socket_base::socketToPort(&sock->addr)) {
// Port is forced by env. Make sure we get the port.
int opt = 1;
#if defined(SO_REUSEPORT)
SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
#else
SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt");
#endif
}
// addr port should be 0 (Any port)
SYSCHECK(bind(sock->fd, &sock->addr.sa, sock->salen), "bind");
/* Get the assigned port */
socklen_t size = sock->salen;
SYSCHECK(getsockname(sock->fd, &sock->addr.sa, &size), "getsockname");
char line[SOCKET_NAME_MAXLEN + 1];
INFO(SCCL_LOG_NET, "Listening on socket %s", scclSocketToString(&sock->addr, line));
/* Put the socket in listen mode
* NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn
*/
SYSCHECK(listen(sock->fd, 16384), "listen");
sock->state = scclSocketStateReady;
return scclSuccess;
}
scclResult_t scclSocketGetAddr(struct scclSocket* sock, union scclSocketAddress* addr) {
if(sock == NULL) {
WARN("scclSocketGetAddr: pass NULL socket");
return scclInvalidArgument;
}
if(sock->state != scclSocketStateReady)
return scclInternalError;
memcpy(addr, &sock->addr, sizeof(union scclSocketAddress));
return scclSuccess;
}
/**
* @brief Establish a socket connection.
*
* Checks the socket state, sets TCP_NODELAY, and, depending on portReuse,
* optionally reuses a local port: an available port is picked from a
* predefined port pool and bound before the connect.
*
* @param sock pointer to the scclSocket structure
* @param portReuse whether to enable port reuse (1 enables / 0 disables)
*
* @return scclResult_t
* - scclSuccess: operation succeeded
* - scclInvalidArgument: invalid argument
* - scclInternalError: internal error
* - scclRemoteError: remote error
* - scclSystemError: system error
*/
scclResult_t scclSocketConnect(struct scclSocket* sock, int portReuse) {
char line[SOCKET_NAME_MAXLEN + 1];
const int one = 1;
if(sock == NULL) {
WARN("scclSocketConnect: pass NULL socket");
return scclInvalidArgument;
}
if(sock->fd == -1) {
WARN("scclSocketConnect: file descriptor is -1");
return scclInvalidArgument;
}
if(sock->state != scclSocketStateInitialized) {
WARN("scclSocketConnect: wrong socket state %d", sock->state);
if(sock->state == scclSocketStateError)
return scclRemoteError;
return scclInternalError;
}
INFO(SCCL_LOG_NET, "Connecting to socket %s", scclSocketToString(&sock->addr, line));
SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt");
if(portReuse) {
int family = sock->addr.sa.sa_family;
if(family != AF_INET && family != AF_INET6) {
WARN("Net : connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)",
scclSocketToString(&sock->addr, line),
family,
AF_INET,
AF_INET6);
return scclInternalError;
}
int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
// Pre-define candidate ports from the thread id, avoiding an extra lock on the pool.
if(socket_base::clientPortPool.empty()) {
int tid = syscall(SYS_gettid);
for(int i = 1; i < 5; i++) {
socket_base::clientPortPool.push_back(std::make_pair(60000 + i * 1000 + tid % 1000, std::unordered_set<std::string>()));
}
}
// find a port without conflict (different remote peer) in best effort
int reused_port = -1;
std::string remote_peer(scclSocketToString(&sock->addr, line));
for(auto& port : socket_base::clientPortPool) {
if(port.second.find(remote_peer) == port.second.end()) {
reused_port = port.first;
port.second.insert(remote_peer);
break;
}
}
// bind the port in fd for connect system call
if(reused_port != -1) {
int opt = 1;
SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt");
// Bind to the wildcard address on the reused port. Use the address union so
// that the salen passed to bind() matches the sockaddr actually populated.
union scclSocketAddress bindAddr;
memset(&bindAddr, 0, sizeof(bindAddr));
bindAddr.sa.sa_family = family;
if(family == AF_INET)
bindAddr.sin.sin_port = htons(reused_port);
else
bindAddr.sin6.sin6_port = htons(reused_port);
SYSCHECK(bind(sock->fd, &bindAddr.sa, salen), "bind_client_port");
}
}
sock->state = scclSocketStateConnecting;
do {
SCCLCHECK(socket_base::socketProgressState(sock));
} while(sock->asyncFlag == 0 && (sock->abortFlag == NULL || *sock->abortFlag == 0) &&
(sock->state == scclSocketStateConnecting || sock->state == scclSocketStateConnectPolling || sock->state == scclSocketStateConnected));
if(sock->abortFlag && *sock->abortFlag != 0)
return scclInternalError;
switch(sock->state) {
case scclSocketStateConnecting:
case scclSocketStateConnectPolling:
case scclSocketStateConnected:
case scclSocketStateReady: return scclSuccess;
case scclSocketStateError: return scclSystemError;
default: WARN("scclSocketConnect: wrong socket state %d", sock->state); return scclInternalError;
}
}
/**
* Check whether a socket is ready, advancing its state if needed.
*
* @param sock socket to check; if NULL, running is simply set to 0
* @param running output parameter: 1 if the socket is ready, 0 otherwise
* @return scclSuccess on success
* @return scclRemoteError if the socket is in the error or closed state
*
* If the socket is not ready yet, this function advances its state machine
* once and then reports readiness through running.
*/
scclResult_t scclSocketReady(struct scclSocket* sock, int* running) {
if(sock == NULL) {
*running = 0;
return scclSuccess;
}
if(sock->state == scclSocketStateError || sock->state == scclSocketStateClosed) {
WARN("scclSocketReady: unexpected socket state %d", sock->state);
return scclRemoteError;
}
*running = (sock->state == scclSocketStateReady) ? 1 : 0;
if(*running == 0) {
SCCLCHECK(socket_base::socketProgressState(sock));
*running = (sock->state == scclSocketStateReady) ? 1 : 0;
}
return scclSuccess;
}
/**
* @brief Accept an incoming socket connection.
*
* @param sock socket object receiving the connection
* @param listenSock socket object in the listening state
* @return scclResult_t
* - scclSuccess: operation succeeded
* - scclInvalidArgument: invalid argument
* - scclSystemError: system error
* - scclInternalError: internal error
*
* @note Blocks until a connection is accepted or an error occurs. If asyncFlag
* is set, the connection is progressed in the background instead; the
* operation can be aborted through abortFlag.
*/
scclResult_t scclSocketAccept(struct scclSocket* sock, struct scclSocket* listenSock) {
scclResult_t ret = scclSuccess;
if(listenSock == NULL || sock == NULL) {
WARN("scclSocketAccept: pass NULL socket");
ret = scclInvalidArgument;
goto exit;
}
if(listenSock->state != scclSocketStateReady) {
WARN("scclSocketAccept: wrong socket state %d", listenSock->state);
if(listenSock->state == scclSocketStateError)
ret = scclSystemError;
else
ret = scclInternalError;
goto exit;
}
if(sock->acceptFd == -1) {
memcpy(sock, listenSock, sizeof(struct scclSocket));
sock->acceptFd = listenSock->fd;
sock->state = scclSocketStateAccepting;
}
do {
SCCLCHECKGOTO(socket_base::socketProgressState(sock), ret, exit);
} while(sock->asyncFlag == 0 && (sock->abortFlag == NULL || *sock->abortFlag == 0) &&
(sock->state == scclSocketStateAccepting || sock->state == scclSocketStateAccepted));
if(sock->abortFlag && *sock->abortFlag != 0)
return scclInternalError;
switch(sock->state) {
case scclSocketStateAccepting:
case scclSocketStateAccepted:
case scclSocketStateReady: ret = scclSuccess; break;
case scclSocketStateError: ret = scclSystemError; break;
default:
WARN("scclSocketAccept: wrong socket state %d", sock->state);
ret = scclInternalError;
break;
}
exit:
return ret;
}
/**
* Get the file descriptor of a socket.
*
* @param sock pointer to the scclSocket structure
* @param fd pointer receiving the file descriptor
* @return scclResult_t
* - scclSuccess: file descriptor returned
* - scclInvalidArgument: sock is NULL
*/
scclResult_t scclSocketGetFd(struct scclSocket* sock, int* fd) {
if(sock == NULL) {
WARN("scclSocketGetFd: pass NULL socket");
return scclInvalidArgument;
}
if(fd)
*fd = sock->fd;
return scclSuccess;
}
/**
* Set the file descriptor of a socket.
*
* @param fd file descriptor to set
* @param sock pointer to the target socket structure
* @return scclResult_t scclSuccess on success, scclInvalidArgument if sock is NULL
*/
scclResult_t scclSocketSetFd(int fd, struct scclSocket* sock) {
if(sock == NULL) {
WARN("scclSocketGetFd: pass NULL socket");
return scclInvalidArgument;
}
sock->fd = fd;
return scclSuccess;
}
} // namespace host
} // namespace net
} // namespace hardware
} // namespace sccl
#pragma once
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <fcntl.h>
#include <poll.h>
#include "base.h"
namespace sccl {
namespace hardware {
namespace net {
namespace host {
#define MAX_IFS 16 // maximum number of interfaces
#define MAX_IF_NAME_SIZE 16 // maximum length of an interface name
#define SLEEP_INT 1000 // sleep interval between connection retries, in microseconds
#define RETRY_REFUSED_TIMES 2e4 // connection-refused retries before reporting a timeout (20 seconds total)
#define RETRY_TIMEDOUT_TIMES 3 // connection-timeout retries (each retry can take 20 seconds)
#define SOCKET_NAME_MAXLEN (NI_MAXHOST + NI_MAXSERV) // maximum socket name length, host plus service
#define SCCL_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL // magic number identifying a socket
/* Union holding a generic IPv4/IPv6 socket address */
union scclSocketAddress {
struct sockaddr sa; // generic socket address
struct sockaddr_in sin; // IPv4 socket address
struct sockaddr_in6 sin6; // IPv6 socket address
};
enum scclSocketState : uint8_t {
scclSocketStateNone = 0, // undefined
scclSocketStateInitialized = 1, // initialized
scclSocketStateAccepting = 2, // accepting a connection
scclSocketStateAccepted = 3, // connection accepted
scclSocketStateConnecting = 4, // connecting
scclSocketStateConnectPolling = 5, // polling for connect completion
scclSocketStateConnected = 6, // connected
scclSocketStateReady = 7, // ready
scclSocketStateClosed = 8, // closed
scclSocketStateError = 9, // error
scclSocketStateNum = 10 // number of states
};
enum scclSocketType : uint8_t {
scclSocketTypeUnknown = 0, // unknown
scclSocketTypeBootstrap = 1, // bootstrap
scclSocketTypeProxy = 2, // proxy
scclSocketTypeNetSocket = 3, // network socket
scclSocketTypeNetIb = 4 // network InfiniBand
};
struct scclSocket {
int fd; // file descriptor
int acceptFd; // file descriptor used by accept()
int timedOutRetries; // timed-out retry count
int refusedRetries; // refused retry count
union scclSocketAddress addr; // socket address
volatile uint32_t* abortFlag; // abort flag
int asyncFlag; // asynchronous flag
enum scclSocketState state; // socket state
int salen; // address length
uint64_t magic; // magic number
enum scclSocketType type; // socket type
};
#define SCCL_SOCKET_SEND 0
#define SCCL_SOCKET_RECV 1
//////////////////////////////////// Socket utilities ////////////////////////////////////
// Convert an address to a string
const char* scclSocketToString(const union scclSocketAddress* addr, char* buf, const int numericHostForm = 1);
// Parse an address from a string
scclResult_t scclSocketGetAddrFromString(union scclSocketAddress* ua, const char* ip_port_pair);
// Find interfaces on the same subnet as a remote address
int scclFindInterfaceMatchSubnet(char* ifNames, union scclSocketAddress* localAddrs, union scclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs);
// Find usable socket network interfaces
int scclFindSocketInterfaces(char* ifNames, union scclSocketAddress* ifAddrs, int ifNameMaxSize, int maxIfs);
//////////////////////////////////// Basic socket operations ////////////////////////////////////
// Progress a socket operation
scclResult_t scclSocketProgress(int op, struct scclSocket* sock, void* ptr, int size, int* offset);
// Wait for a socket operation to complete
scclResult_t scclSocketWait(int op, struct scclSocket* sock, void* ptr, int size, int* offset);
// Send data
scclResult_t scclSocketSend(struct scclSocket* sock, void* ptr, int size);
// Receive data
scclResult_t scclSocketRecv(struct scclSocket* sock, void* ptr, int size);
// Try to receive data, detecting a closed connection
scclResult_t scclSocketTryRecv(struct scclSocket* sock, void* ptr, int size, int* closed, bool blocking);
// Close a socket
scclResult_t scclSocketClose(struct scclSocket* sock);
//////////////////////////////////// Socket lifecycle ////////////////////////////////////
// Initialize a socket
scclResult_t scclSocketInit(struct scclSocket* sock,
union scclSocketAddress* addr = NULL,
uint64_t magic = SCCL_SOCKET_MAGIC,
enum scclSocketType type = scclSocketTypeUnknown,
volatile uint32_t* abortFlag = NULL,
int asyncFlag = 0);
// Create a listening socket. sock->addr may be pre-filled with the IP and port; sock->fd is set on a successful call
scclResult_t scclSocketListen(struct scclSocket* sock);
// Get the socket address
scclResult_t scclSocketGetAddr(struct scclSocket* sock, union scclSocketAddress* addr);
// Connect to sock->addr. sock->fd is set on a successful call.
scclResult_t scclSocketConnect(struct scclSocket* sock, int portReuse = 0);
// Return the socket connection state.
scclResult_t scclSocketReady(struct scclSocket* sock, int* running);
// Accept an incoming connection from listenSock->fd; the file descriptor is kept in sock->fd, the remote IP/port in sock->addr.
scclResult_t scclSocketAccept(struct scclSocket* sock, struct scclSocket* listenSock);
// Get the socket file descriptor
scclResult_t scclSocketGetFd(struct scclSocket* sock, int* fd);
// Set the socket file descriptor
scclResult_t scclSocketSetFd(int fd, struct scclSocket* sock);
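// Usage sketch (illustrative only; error checking omitted, the address and
// payload are hypothetical). A minimal rendezvous over this API:
//
//   struct scclSocket listenSock, client, server;
//   union scclSocketAddress addr;
//   scclSocketGetAddrFromString(&addr, "192.168.1.1:0"); // port 0 = any port
//   scclSocketInit(&listenSock, &addr);
//   scclSocketListen(&listenSock);
//   scclSocketGetAddr(&listenSock, &addr);               // learn the real port
//
//   scclSocketInit(&client, &addr);                      // client side
//   scclSocketConnect(&client);
//   int value = 42;
//   scclSocketSend(&client, &value, sizeof(value));
//
//   scclSocketInit(&server);                             // server side
//   scclSocketAccept(&server, &listenSock);
//   scclSocketRecv(&server, &value, sizeof(value));
//   scclSocketClose(&server);
//   scclSocketClose(&client);
//   scclSocketClose(&listenSock);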
} // namespace host
} // namespace net
} // namespace hardware
} // namespace sccl
#pragma once
#include <stdint.h>
#include "base.h"
#include "net_utils.h"
#include "device/net_ib.h"
#include "host/net_socket.h"
namespace sccl {
namespace hardware {
namespace net {
//////////////////////////////////
typedef enum net_type : uint8_t {
NET_IB = 0,
NET_SOCKET = 1
} net_type_t;
//////////////////////////////////
inline scclResult_t initNetSpecial(scclNet_t* net) {
int ndev;
// Initialize the network; report an internal error on failure
if(net->init() != scclSuccess)
return scclInternalError;
// Query the device count; report an internal error on failure
if(net->devices(&ndev) != scclSuccess)
return scclInternalError;
// No usable device counts as a system error
if(ndev <= 0)
return scclSystemError;
return scclSuccess;
}
/**
* Initialize a network transport by type.
*
* @param t the transport to initialize (NET_IB or NET_SOCKET)
* @return scclNet_t* pointer to the initialized transport, or NULL when
* initialization fails or no device is available
*/
inline scclNet_t* initNet(net_type_t t) {
scclNet_t* scclNet = NULL;
if(t == NET_IB) {
if(initNetSpecial(&(device::scclNetIb)) == scclSuccess) {
scclNet = &(device::scclNetIb);
}
} else if(t == NET_SOCKET) {
if(initNetSpecial(&(host::scclNetSocket)) == scclSuccess) {
scclNet = &(host::scclNetSocket);
}
} else {
WARN("Unsupported network type.");
}
return scclNet;
}
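// Usage sketch (illustrative): select the socket transport and query its
// device count through the returned vtable.
//   scclNet_t* net = initNet(NET_SOCKET);
//   if(net != NULL) {
//     int ndev = 0;
//     net->devices(&ndev); // > 0, already validated by initNetSpecial()
//   }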
////////////////////////////////////
inline scclNet_t* scclNets[3] = {nullptr, &device::scclNetIb, &host::scclNetSocket};
} // namespace net
} // namespace hardware
} // namespace sccl
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include "net_utils.h"
namespace sccl {
namespace hardware {
namespace net {
/**
* Parse a list of network interfaces from a string.
*
* @param string input string of the form "prefix:port,prefix:port,..."
* @param ifList output array of parsed network interface entries
* @param maxList maximum number of interfaces to parse
* @return the number of interfaces actually parsed
*
* Turns a string such as "eth0:1234,ib0:5678" into an array of interface
* structures. Each entry holds a prefix and a port; the port is set to -1
* when unspecified. Parsing stops at maxList entries or at the end of the
* string.
*/
int parseStringList(const char* string, struct netIf* ifList, int maxList) {
if(!string)
return 0;
const char* ptr = string;
int ifNum = 0;
int ifC = 0;
char c;
do {
c = *ptr;
if(c == ':') {
if(ifC > 0) {
ifList[ifNum].prefix[ifC] = '\0';
ifList[ifNum].port = atoi(ptr + 1);
ifNum++;
ifC = 0;
}
while(c != ',' && c != '\0')
c = *(++ptr);
} else if(c == ',' || c == '\0') {
if(ifC > 0) {
ifList[ifNum].prefix[ifC] = '\0';
ifList[ifNum].port = -1;
ifNum++;
ifC = 0;
}
} else {
ifList[ifNum].prefix[ifC] = c;
ifC++;
}
ptr++;
} while(ifNum < maxList && c);
return ifNum;
}
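// Worked example (illustrative): parseStringList("eth0:1234,ib0", ifList, MAX_IFS)
// returns 2 with ifList[0] = { prefix "eth0", port 1234 } and
// ifList[1] = { prefix "ib0", port -1 } (no port given).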
static bool matchIf(const char* string, const char* ref, bool matchExact) {
// Make sure to include '\0' in the exact case
int matchLen = matchExact ? strlen(string) + 1 : strlen(ref);
return strncmp(string, ref, matchLen) == 0;
}
static bool matchPort(const int port1, const int port2) {
if(port1 == -1)
return true;
if(port2 == -1)
return true;
if(port1 == port2)
return true;
return false;
}
/**
* Check whether a string and port match any entry of an interface list.
*
* @param string string to match
* @param port port to match
* @param ifList network interface list
* @param listSize size of the interface list
* @param matchExact whether an exact match is required
* @return true on a match, false otherwise
* @note When listSize is 0, returns true
*/
bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact) {
// Make an exception for the case where no user list is defined
if(listSize == 0)
return true;
for(int i = 0; i < listSize; i++) {
if(matchIf(string, ifList[i].prefix, matchExact) && matchPort(port, ifList[i].port)) {
return true;
}
}
return false;
}
} // namespace net
} // namespace hardware
} // namespace sccl