#include "comm.h"
#include "graph.h"
#include "trees.h"
#include "rings.h"
#include "topo.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace detect {

/******************************************************************/
/********************* Internode connection ***********************/
/******************************************************************/

/**
 * Initialise the intra-node connections of every channel for the calling rank.
 *
 * Each of the comm->nChannels channels is first reset to "not connected" (-1 links,
 * zero collnetDirect head count/shift). The rank is then looked up in the per-channel
 * intra-node orderings of the ring, tree and collnet-chain graphs, and its neighbours
 * in those orderings are written into the channel. The per-channel summary (ring
 * entry/exit ranks, ring prev/next, tree parent/child ranks, first NVLS rank) is
 * written into topoRanks. Finally the first nChannels channels are copied right after
 * themselves when nChannels <= MAXCHANNELS / 2.
 *
 * @param comm      communicator whose channels are initialised
 * @param graphs    graphs indexed by SCCL_ALGO_*; RING, TREE, COLLNET_CHAIN and NVLS are read
 * @param topoRanks output summary, indexed by channel
 * @return scclSuccess (no failure path in this function)
 */
scclResult_t scclTopoPreset(struct scclComm* comm, struct scclTopoGraph** graphs, struct scclTopoRanks* topoRanks) {
    const int myRank      = comm->rank;
    const int nLocalRanks = comm->topo->nodes[GPU].count;
    const int nChannels   = comm->nChannels;

    for(int c = 0; c < nChannels; c++) {
        struct scclChannel* ch = &comm->channels[c];

        // Start from a fully disconnected channel.
        ch->ring.prev       = -1;
        ch->ring.next       = -1;
        ch->tree.up         = -1;
        ch->collnetChain.up = -1;
        for(int k = 0; k < SCCL_MAX_TREE_ARITY; k++) {
            ch->tree.down[k]         = -1;
            ch->collnetChain.down[k] = -1;
        }
        ch->collnetDirect.out      = -1;
        ch->collnetDirect.headRank = -1;
        ch->collnetDirect.nHeads   = 0;
        ch->collnetDirect.shift    = 0;
        for(int k = 0; k < SCCL_MAX_DIRECT_ARITY; k++) {
            ch->collnetDirect.up[k]   = -1;
            ch->collnetDirect.down[k] = -1;
        }

        // Each graph stores one row of nLocalRanks ranks per channel.
        const int rowOffset = c * nLocalRanks;
        int* ringOrder      = graphs[SCCL_ALGO_RING]->intra + rowOffset;
        int* treeOrder      = graphs[SCCL_ALGO_TREE]->intra + rowOffset;
        int* chainOrder     = graphs[SCCL_ALGO_COLLNET_CHAIN]->intra + rowOffset;
        int* nvlsOrder      = graphs[SCCL_ALGO_NVLS]->intra + rowOffset;

        for(int pos = 0; pos < nLocalRanks; pos++) {
            const bool isFirst = (pos == 0);
            const bool isLast  = (pos == nLocalRanks - 1);

            if(ringOrder[pos] == myRank) {
                // The row's first/last ranks are the node's ring entry/exit points.
                topoRanks->ringRecv[c] = ringOrder[0];
                topoRanks->ringSend[c] = ringOrder[nLocalRanks - 1];
                ch->ring.prev          = -1;
                if(!isFirst)
                    ch->ring.prev = ringOrder[pos - 1];
                ch->ring.next = -1;
                if(!isLast)
                    ch->ring.next = ringOrder[pos + 1];
            }

            if(treeOrder[pos] == myRank) {
                // Which row positions act as the node's child links depends on the tree pattern.
                int child0Pos = 1;
                if(graphs[SCCL_ALGO_TREE]->pattern == SCCL_TOPO_PATTERN_TREE)
                    child0Pos = 0;
                int child1Pos = 0;
                if(graphs[SCCL_ALGO_TREE]->pattern == SCCL_TOPO_PATTERN_SPLIT_TREE)
                    child1Pos = 1;

                topoRanks->treeToParent[c] = treeOrder[0];
                topoRanks->treeToChild0[c] = treeOrder[child0Pos];
                topoRanks->treeToChild1[c] = treeOrder[child1Pos];

                // Inside the node the tree is a chain following the row order.
                ch->tree.up = -1;
                if(!isFirst)
                    ch->tree.up = treeOrder[pos - 1];
                ch->tree.down[0] = -1;
                if(!isLast)
                    ch->tree.down[0] = treeOrder[pos + 1];
            }

            if(chainOrder[pos] == myRank) {
                // The first rank of the chain points up to the pseudo-rank comm->nRanks.
                if(isFirst)
                    ch->collnetChain.up = comm->nRanks;
                else
                    ch->collnetChain.up = chainOrder[pos - 1];
                ch->collnetChain.down[0] = -1;
                if(!isLast)
                    ch->collnetChain.down[0] = chainOrder[pos + 1];
            }
        }

        topoRanks->ringPrev[c]  = ch->ring.prev;
        topoRanks->ringNext[c]  = ch->ring.next;
        topoRanks->nvlsHeads[c] = nvlsOrder[0];
    }

    // Duplicate the channels (rings/trees) when the copy still fits in the channel array.
    struct scclChannel* firstChannel = comm->channels;
    struct scclChannel* copyTarget   = nullptr;
    if(nChannels <= MAXCHANNELS / 2)
        copyTarget = firstChannel + nChannels;
    if(copyTarget != nullptr)
        memcpy(copyTarget, firstChannel, nChannels * sizeof(struct scclChannel));
    return scclSuccess;
}

/**
 * Report whether the decimal number `rank` appears in the window s[start, end).
 *
 * The window is scanned left to right. Consecutive digits are accumulated into one
 * number; '(' and ')' reset the accumulator. After the last digit of a number (and on
 * any character that is neither a digit nor a parenthesis) the accumulated value is
 * compared with `rank`.
 *
 * Note that the "last digit" test looks at s[pos + 1], which may be s[end]; the caller
 * must therefore guarantee that s[end] is readable (e.g. a NUL or a parenthesis).
 *
 * @param s     text to scan; a null pointer yields false
 * @param start first index of the window (inclusive)
 * @param end   end of the window (exclusive)
 * @param rank  number to look for
 * @return true when `rank` was found, false otherwise or when the window is empty/invalid
 */
bool isRankHere(const char* s, int start, int end, int rank) {
    if(s == nullptr || start < 0 || end < 0 || end <= start)
        return false;

    int value = 0;
    for(int pos = start; pos < end; pos++) {
        // <cctype> classifiers are only defined for unsigned char values (or EOF),
        // so never hand them a possibly negative plain char.
        const unsigned char ch = static_cast<unsigned char>(s[pos]);

        if(ch == '(' || ch == ')') {
            value = 0;
            continue;
        }
        if(isdigit(ch)) {
            value = value * 10 + (ch - '0');
            // More digits follow: keep accumulating before comparing.
            if(isdigit(static_cast<unsigned char>(s[pos + 1])))
                continue;
        }
        if(value == rank)
            return true;
    }
    return false;
}

/**
 * Override the tree links of every channel from the textual trees in treeGraph->treeBase.
 *
 * treeBase is read as a list of NUL-terminated strings, the list itself being terminated
 * by a row whose first character is 0. The parser below only distinguishes '(' , ')' and
 * decimal digits, i.e. it expects a nested notation in which each node is written as
 * "(" <rank> <child subtrees...> ")" -- NOTE(review): format inferred from the parser,
 * confirm against the producer of treeBase.
 *
 * For channel c the string treeBase[c % <number of rows>] is walked from the root down to
 * the node holding comm->rank; channel->tree.up is set to the rank of the enclosing node
 * (-1 for the root) and channel->tree.down[] to the ranks of its direct children (-1 for
 * unused slots). If the rank is not found in the string the channel is left untouched.
 *
 * @param comm      communicator whose channels are updated
 * @param treeGraph graph providing treeBase
 * @return scclSuccess (also when treeBase is empty, in which case nothing is changed)
 */
scclResult_t scclTreeBasePostset(struct scclComm* comm, struct scclTopoGraph* treeGraph) {
    // x ends up as the number of non-empty rows in treeBase.
    // NOTE(review): y is never used.
    int x = 0, y = 0;
    for(int i = 0; treeGraph->treeBase[i][0] != 0; i++) {
        x = i + 1;
    }
    // No user-provided tree: keep whatever links the channels already have.
    if(treeGraph->treeBase[0][0] == 0)
        return scclSuccess;
    int nChannels  = comm->nChannels;
    // NOTE(review): localRanks is never used below.
    int localRanks = comm->topo->nodes[GPU].count;
    // Rebuild the tree links of each channel from its tree string
    for(int c = 0; c < nChannels; c++) { // channels cycle through the available rows
        int buff = c % x;
        // Work on a private copy of the row.
        // NOTE(review): the copy loop has no bound check against sizeof(tempString).
        char tempString[SCCL_TOPO_MAX_NODES * 4];
        int ko = 0;
        while(treeGraph->treeBase[buff][ko] != 0) {
            tempString[ko] = treeGraph->treeBase[buff][ko];
            ko++;
        }
        tempString[ko]              = 0;
        int start                   = 0;
        int curRank                 = comm->rank;
        struct scclChannel* channel = comm->channels + c;
        // [start, end) is the window currently being searched; it begins as the whole string.
        int end                     = 0;
        while(tempString[end] != 0)
            end++;
        // Rank of the node enclosing the current window (-1 while at the root).
        int parent = -1;
        // Each iteration reads the rank at the head of the window, then either stops (rank found)
        // or narrows the window to the child subtree that contains the rank.
        while(start < end) {
            int num = 0, num_found = 0;
            // Skip the opening '(' of the current node.
            start++;
            // Build the node's rank from the run of digits up to the next parenthesis.
            while(start < end && tempString[start] != '(' && tempString[start] != ')') {
                int num_here = (int)(tempString[start] - '0');
                num          = num * 10 + num_here;
                start        = start + 1;
                if(tempString[start] == '(' || tempString[start] == ')' || start == end)
                    num_found = 1;
            }
            if(num_found != 0 && num == curRank) {
                // Found our node: its parent is the enclosing node, its children follow at `start`.
                channel->tree.up = parent;
                int depth        = 0;
                for(int childId = 0; childId < SCCL_MAX_TREE_ARITY; childId++) {
                    int or_start                = start;
                    int child                   = -1;
                    channel->tree.down[childId] = -1;
                    // Nothing left in the window: the remaining child slots stay at -1.
                    if(or_start >= end - 1)
                        continue;
                    num = 0;
                    // Skip the child's '(' and read its rank.
                    or_start++;
                    while(tempString[or_start] != 0 && tempString[or_start] != '(' && tempString[or_start] != ')') {
                        int num_here = (int)(tempString[or_start] - '0');
                        num          = num * 10 + num_here;
                        or_start++;
                    }
                    child = num;
                    // Advance `start` to the ')' closing this child subtree (balanced parentheses)
                    while(start < end) {
                        if(tempString[start] == '(')
                            depth++;
                        else if(tempString[start] == ')')
                            depth--;
                        if(depth == 0)
                            break; // end of this child subtree
                        start++;
                    }
                    // Step past the closing ')' so `start` is at the next sibling, if any.
                    start++;
                    channel->tree.down[childId] = child;
                    // Child recorded; continue with the next child slot
                }
                break;
            } else { // not our node: descend into the child subtree that contains curRank
                parent      = num;
                int start_c = start;
                int end_c   = start_c;
                while(end_c < end) {
                    // Find the ')' that closes the subtree starting at start_c.
                    int depth = 0;
                    while(end_c < end) {
                        if(tempString[end_c] == '(')
                            depth++;
                        else if(tempString[end_c] == ')')
                            depth--;
                        if(depth == 0)
                            break; // end of this child subtree
                        end_c++;
                    }
                    if(isRankHere(tempString, start_c, end_c, curRank)) {
                        // Narrow the window to this subtree and parse it in the next iteration.
                        start = start_c;
                        end   = end_c;
                        break;
                    } else {
                        // Try the next sibling subtree.
                        end_c++;
                        start_c = end_c;
                    }
                }
            }
        }
    }
    return scclSuccess;
}

/**
 * Close the per-node ring segments into full rings across nodes.
 *
 * For every channel, the rank that receives on node n is linked to the rank that sends
 * on node n-1, and the rank that sends on node n is linked to the rank that receives on
 * node n+1 (both modulo the node count). The links are stored in ringPrev/ringNext and,
 * for the calling rank, in the channel itself and in its duplicate (which exists only
 * when comm->nChannels <= MAXCHANNELS / 2).
 *
 * @param comm     communicator
 * @param ringRecv [channel][node] rank receiving from the previous node
 * @param ringSend [channel][node] rank sending to the next node
 * @param ringPrev [channel][rank] in/out table of previous ranks
 * @param ringNext [channel][rank] in/out table of next ranks
 * @return scclSuccess
 */
static scclResult_t connectRings(struct scclComm* comm, int* ringRecv, int* ringSend, int* ringPrev, int* ringNext) {
    const int nNodes    = comm->nNodes;
    const int nChannels = comm->nChannels;

    for(int c = 0; c < nChannels; c++) {
        int* nodeEntry = ringRecv + c * comm->nNodes;
        int* nodeExit  = ringSend + c * comm->nNodes;
        int* prevOf    = ringPrev + c * comm->nRanks;
        int* nextOf    = ringNext + c * comm->nRanks;

        struct scclChannel* primary = comm->channels + c;
        struct scclChannel* twin    = nullptr;
        if(nChannels <= MAXCHANNELS / 2)
            twin = primary + nChannels;

        for(int n = 0; n < nNodes; n++) {
            // Entry rank of this node <- exit rank of the previous node.
            const int entryRank = nodeEntry[n];
            const int fromRank  = nodeExit[(n - 1 + nNodes) % nNodes];
            prevOf[entryRank]   = fromRank;
            if(entryRank == comm->rank) {
                primary->ring.prev = fromRank;
                if(twin != nullptr)
                    twin->ring.prev = fromRank;
            }

            // Exit rank of this node -> entry rank of the next node.
            const int exitRank = nodeExit[n];
            const int toRank   = nodeEntry[(n + 1) % nNodes];
            nextOf[exitRank]   = toRank;
            if(exitRank == comm->rank) {
                primary->ring.next = toRank;
                if(twin != nullptr)
                    twin->ring.next = toRank;
            }
        }
    }
    return scclSuccess;
}

/**
 * Copy the first nNodes entries of `ranks` into `indexes`.
 *
 * @return scclSuccess
 */
static scclResult_t getIndexes(int* ranks, int* indexes, int nNodes) {
    int n = 0;
    while(n < nNodes) {
        indexes[n] = ranks[n];
        ++n;
    }
    return scclSuccess;
}

/**
 * Set the parent link of `tree` to indexes[u]; u == -1 means "no parent" and leaves
 * the tree unchanged.
 *
 * @return scclSuccess
 */
static scclResult_t setTreeUp(struct scclTree* tree, int* indexes, int u) {
    if(u != -1) {
        tree->up = indexes[u];
    }
    return scclSuccess;
}

/**
 * Append indexes[d] as a child of `tree`, using the first slot of tree->down that is
 * still negative; d == -1 means "no child" and leaves the tree unchanged.
 *
 * @return scclSuccess, or scclInternalError when all SCCL_MAX_TREE_ARITY slots are taken
 */
static scclResult_t setTreeDown(struct scclTree* tree, int* indexes, int d) {
    if(d == -1)
        return scclSuccess;

    // Locate the first free child slot.
    int slot = 0;
    for(; slot < SCCL_MAX_TREE_ARITY; slot++) {
        if(tree->down[slot] < 0)
            break;
    }
    if(slot == SCCL_MAX_TREE_ARITY) {
        WARN("Internal error : tree already has %d children (%d %d %d)", slot, tree->down[0], tree->down[1], tree->down[2]);
        return scclInternalError;
    }

    tree->down[slot] = indexes[d];
    return scclSuccess;
}

/**
 * Apply one inter-node tree (parent `up`, children `down0`/`down1`, all node indexes or
 * -1) to a single channel, for the calling rank.
 *
 * @param comm      communicator (rank and node are read)
 * @param channel   channel whose tree is updated
 * @param channelId absolute index of `channel` in comm->channels, used for logging only
 * @param ttp       per-node rank that talks to the parent node
 * @param ttc0      per-node rank that talks to the first child node
 * @param ttc1      per-node rank that talks to the second child node
 * @param up        parent node index, or -1
 * @param down0     first child node index, or -1
 * @param down1     second child node index, or -1
 * @param childType 0 when this node is its parent's first child (connect to the parent's
 *                  ttc0 rank), otherwise connect to the parent's ttc1 rank
 * @param depth     value stored in channel->tree.depth
 * @return scclSuccess, or the error reported by setTreeUp/setTreeDown
 */
static scclResult_t connectTreeChannel(struct scclComm* comm,
                                       struct scclChannel* channel,
                                       int channelId,
                                       int* ttp,
                                       int* ttc0,
                                       int* ttc1,
                                       int up,
                                       int down0,
                                       int down1,
                                       int childType,
                                       int depth) {
    const int rank = comm->rank;
    const int node = comm->node;

    if(rank == ttp[node]) {
        SCCLCHECK(setTreeUp(&channel->tree, childType == 0 ? ttc0 : ttc1, up));
    }
    if(rank == ttc0[node]) {
        SCCLCHECK(setTreeDown(&channel->tree, ttp, down0));
    }
    if(rank == ttc1[node]) {
        SCCLCHECK(setTreeDown(&channel->tree, ttp, down1));
    }
    if(rank == ttp[node] || rank == ttc0[node] || rank == ttc1[node]) {
        INFO(SCCL_LOG_TOPO,
             "Tree %d : %d -> %d -> %d/%d/%d",
             channelId,
             channel->tree.up,
             rank,
             channel->tree.down[0],
             channel->tree.down[1],
             channel->tree.down[2]);
    }
    channel->tree.depth = depth;
    return scclSuccess;
}

/**
 * Connect the intra-node trees of all channels across nodes using the two trees
 * returned by scclGetDtree.
 *
 * Two layouts are handled:
 *  - comm->nChannels <= MAXCHANNELS / 2: channel c receives tree 0 and its duplicate
 *    c + nChannels receives tree 1; both use the rank tables of channel c.
 *  - otherwise: the channels are split in two halves; the first half receives tree 0,
 *    the second half tree 1, and every channel uses its own rank tables.
 *
 * @param comm         communicator
 * @param treeToParent [channel][node] rank connected to the parent node
 * @param treeToChild0 [channel][node] rank connected to the first child node
 * @param treeToChild1 [channel][node] rank connected to the second child node
 * @param treePatterns not used by this function
 * @return scclSuccess, or the first error reported by a callee
 */
static scclResult_t connectTrees(struct scclComm* comm, int* treeToParent, int* treeToChild0, int* treeToChild1, int* treePatterns) {
    const bool splitHalves = comm->nChannels > MAXCHANNELS / 2;
    const int nChannels    = splitHalves ? comm->nChannels / 2 : comm->nChannels;
    const int nNodes       = comm->nNodes;
    const int node         = comm->node;

    // Compute tree depth. Not an exact value but a good approximation in most
    // cases
    const int depth = comm->nRanks / nNodes - 1 + log2i(nNodes);

    int t0u, t0d0, t0d1, t0ChildType, t1u, t1d0, t1d1, t1ChildType;
    SCCLCHECK(scclGetDtree(nNodes, node, &t0u, &t0d0, &t0d1, &t0ChildType, &t1u, &t1d0, &t1d1, &t1ChildType));

    for(int c = 0; c < nChannels; c++) {
        int* ttp  = treeToParent + c * nNodes;
        int* ttc0 = treeToChild0 + c * nNodes;
        int* ttc1 = treeToChild1 + c * nNodes;
        SCCLCHECK(connectTreeChannel(comm, comm->channels + c, c, ttp, ttc0, ttc1, t0u, t0d0, t0d1, t0ChildType, depth));
        if(!splitHalves) {
            // The duplicated channel shares the rank tables of channel c but carries the second tree.
            SCCLCHECK(connectTreeChannel(comm, comm->channels + c + nChannels, c + nChannels, ttp, ttc0, ttc1, t1u, t1d0, t1d1, t1ChildType, depth));
        }
    }

    if(splitHalves) {
        for(int c = nChannels; c < nChannels * 2; c++) {
            int* ttp  = treeToParent + c * nNodes;
            int* ttc0 = treeToChild0 + c * nNodes;
            int* ttc1 = treeToChild1 + c * nNodes;
            // c is already the absolute channel index here, so it is also the id to log.
            SCCLCHECK(connectTreeChannel(comm, comm->channels + c, c, ttp, ttc0, ttc1, t1u, t1d0, t1d1, t1ChildType, depth));
        }
    }
    return scclSuccess;
}

/**
 * Fill the collnetDirect links of every channel and the collnetChain depth.
 *
 * The distinct first ranks of the collNet graph's per-channel intra-node rows are the
 * "heads". A head connects down to every other rank of one row and out to the
 * pseudo-rank comm->nRanks; every rank connects up to all heads other than itself.
 * One summary line per channel is logged. Finally, NVLS channels whose headRank is set
 * get their `out` link pointed at comm->nRanks.
 *
 * @param comm         communicator whose channels are updated
 * @param collNetGraph collNet graph (nChannels and intra are read)
 * @return scclSuccess, or the error from scclCalloc
 */
static scclResult_t connectCollNet(struct scclComm* comm, struct scclTopoGraph* collNetGraph) {
    const int self = comm->rank;
    const int ppn  = comm->localRanks;
    int headCount  = 0;
    int* heads;
    SCCLCHECK(scclCalloc(&heads, ppn));

    // Collect the distinct head ranks; the head of a channel is always entry 0 of its row.
    for(int c = 0; c < collNetGraph->nChannels; c++) {
        const int candidate = collNetGraph->intra[c * ppn];
        bool isNew          = (candidate != -1);
        for(int h = 0; isNew && h < headCount; h++) {
            if(heads[h] == candidate)
                isNew = false;
        }
        if(isNew)
            heads[headCount++] = candidate;
    }

    // Position of this rank in the head list, or -1 when it is not a head.
    int myHeadIdx = -1;
    for(int i = 0; i < headCount; i++) {
        if(heads[i] == self) {
            myHeadIdx = i;
            break;
        }
    }

    for(int c = 0; c < comm->nChannels; c++) {
        struct scclChannel* ch = &comm->channels[c];
        char line[1024];
        char* cursor = line;
        cursor += sprintf(cursor, "CollNet channel %d rank %d ", c, self);

        int nDown = 0;
        if(myHeadIdx != -1) {
            ch->collnetDirect.headRank = myHeadIdx;   // index used to derive the offset in the kernel
            ch->collnetDirect.out      = comm->nRanks; // root of collnetDirect is pseudo-rank nRanks
            int* row                   = collNetGraph->intra + myHeadIdx * ppn;
            cursor += sprintf(cursor, "down ");
            for(int r = 0; r < ppn; r++) {
                if(row[r] != self) {
                    ch->collnetDirect.down[nDown++] = row[r]; // connect to all peers
                    cursor += sprintf(cursor, " %d ", row[r]);
                }
            }
            cursor += sprintf(cursor, "nDown %d ", nDown);
        }

        // Connect up to every head except ourselves.
        int nUp = 0;
        cursor += sprintf(cursor, "up ");
        for(int h = 0; h < headCount; h++) {
            if(heads[h] != self) {
                ch->collnetDirect.up[nUp++] = heads[h];
                cursor += sprintf(cursor, " %d ", heads[h]);
            }
        }

        ch->collnetDirect.nHeads = headCount;
        // Shift by intra-node rank so that leaves don't send to the same head simultaneously.
        ch->collnetDirect.shift = (self % ppn) % headCount;
        if(nUp == 0 && nDown == 0)
            ch->collnetDirect.depth = 1;
        else
            ch->collnetDirect.depth = 2;

        cursor += sprintf(cursor, "nUp %d nHeads %d ", nUp, headCount);
        cursor += sprintf(cursor, "headRank %d out %d shift %d", ch->collnetDirect.headRank, ch->collnetDirect.out, ch->collnetDirect.shift);
        INFO(SCCL_LOG_TOPO, "%s", line);

        ch->collnetChain.depth = comm->nRanks / comm->nNodes;
    }

    for(int c = 0; c < comm->nvlsChannels; c++) {
        struct scclChannel* ch = &comm->channels[c];
        if(ch->nvls.headRank != -1)
            ch->nvls.out = comm->nRanks;
    }

    free(heads);
    return scclSuccess;
}

/**
 * Fill the NVLS links of the first comm->nvlsChannels channels.
 *
 * With no NVLS head (nvlsGraph->nChannels == 0) NVLS is disabled by setting
 * comm->nvlsChannels to 0. Otherwise every NVLS channel gets its up/down pseudo-ranks
 * (comm->nRanks + 1 + head index) and reset tree links. On multi-node jobs the two
 * trees from scclGetDtree are then mapped onto the head ranks and stored alternately
 * (channel parity) in nvls.treeUp / nvls.treeDown.
 *
 * @param comm      communicator whose channels are updated
 * @param nvlsHeads [head][node] rank acting as that head on each node
 * @param nvlsGraph NVLS graph (nChannels and intra are read)
 * @return scclSuccess, or the error from scclGetDtree
 */
static scclResult_t connectNvls(struct scclComm* comm, int* nvlsHeads, struct scclTopoGraph* nvlsGraph) {
    const int nHeads = nvlsGraph->nChannels;
    if(nHeads == 0) {
        comm->nvlsChannels = 0;
        return scclSuccess;
    }

    // Head index served by this rank; the last matching head wins, -1 when there is none.
    int myHead = -1;
    for(int h = 0; h < nHeads; h++) {
        if(nvlsGraph->intra[h * comm->localRanks] == comm->rank)
            myHead = h;
    }

    for(int c = 0; c < comm->nvlsChannels; c++) {
        struct scclNvls* nvls = &comm->channels[c].nvls;
        nvls->nHeads          = nHeads;
        for(int h = 0; h < SCCL_MAX_NVLS_ARITY; h++) {
            if(h < nHeads)
                nvls->up[h] = comm->nRanks + 1 + h;
            else
                nvls->up[h] = -1;
        }
        nvls->down        = comm->nRanks + 1 + myHead;
        nvls->out         = -1; // NVLS+SHARP not yet implemented.
        nvls->headRank    = myHead;
        nvls->treeDown[2] = -1;
        nvls->treeDown[1] = -1;
        nvls->treeDown[0] = -1;
        nvls->treeUp      = -1;
        nvls->node        = comm->node;
        nvls->nNodes      = comm->nNodes;
    }
    if(comm->nNodes == 1)
        return scclSuccess;

    // Connect Trees
    int parent0, childA0, childB0, parent1, childA1, childB1;
    int childType0, childType1; // ignored
    SCCLCHECK(scclGetDtree(comm->nNodes, comm->node, &parent0, &childA0, &childB0, &childType0, &parent1, &childA1, &childB1, &childType1));

    int* headRow    = NULL;
    int upRank[2]   = {-1, -1};
    int downA[2]    = {-1, -1};
    int downB[2]    = {-1, -1};

    if(comm->node == 0) {
        for(int h = 0; h < nHeads; h++) {
            char line[1024];
            char* cursor = line;
            cursor += sprintf(cursor, "NVLS Head %2d:", h);
            headRow = nvlsHeads + h * comm->nNodes;
            // Only the first 20 nodes are printed.
            for(int n = 0; n < comm->nNodes && n < 20; n++) {
                cursor += sprintf(cursor, " %2d", headRow[n]);
            }
            INFO(SCCL_INIT, "%s", line);
        }
    }

    // Find the head row in which this rank is the head on its node and translate the
    // tree's node indexes into ranks of that row.
    for(int h = 0; h < nHeads; h++) {
        headRow = nvlsHeads + h * comm->nNodes;
        if(headRow[comm->node] != comm->rank)
            continue;
        int* row     = headRow;
        auto rankAt  = [row](int nodeIdx) { return nodeIdx == -1 ? -1 : row[nodeIdx]; };
        upRank[0]    = rankAt(parent0);
        downA[0]     = rankAt(childA0);
        downB[0]     = rankAt(childB0);
        upRank[1]    = rankAt(parent1);
        downA[1]     = rankAt(childA1);
        downB[1]     = rankAt(childB1);
        break;
    }

    // Set prev/next in all channels (NVLS compute channels work
    // orthogonally to NVLS search channels).
    for(int c = 0; c < comm->nvlsChannels; c++) {
        struct scclNvls* nvls = &comm->channels[c].nvls;
        const int t           = c % 2;
        nvls->treeUp          = upRank[t];
        nvls->treeDown[0]     = nvls->down;
        int slot              = 1;
        if(downA[t] != -1) {
            nvls->treeDown[slot] = downA[t];
            slot++;
        }
        if(downB[t] != -1)
            nvls->treeDown[slot] = downB[t];
    }

    struct scclNvls* first  = &comm->channels[0].nvls;
    struct scclNvls* second = &comm->channels[1].nvls;
    INFO(SCCL_LOG_TOPO,
         "NVLS Trees : %d/%d->%d->%d %d/%d->%d->%d",
         first->treeDown[0],
         first->treeDown[1],
         comm->rank,
         first->treeUp,
         second->treeDown[0],
         second->treeDown[1],
         comm->rank,
         second->treeUp);
    return scclSuccess;
}

// Tunables bounding the number of channels. scclMinNchannels()/scclMaxNchannels() treat the
// value -2 as "not set"; when both names are set, the NCHANNELS value is applied last and wins.
// NOTE(review): assuming the third macro argument is the default value, MinNchannels defaults
// to 4 rather than -2, so scclMinNchannels() uses 4 unless the parameter is overridden -- confirm
// this is intended.
// Legacy naming
SCCL_PARAM(MinNrings, "MIN_NRINGS", -2);
SCCL_PARAM(MaxNrings, "MAX_NRINGS", -2);
// New naming
SCCL_PARAM(MinNchannels, "MIN_NCHANNELS", 4);
SCCL_PARAM(MaxNchannels, "MAX_NCHANNELS", -2);

/**
 * Lower bound on the number of channels.
 *
 * Starts at 2, is overridden by the MinNrings parameter and then by the MinNchannels
 * parameter whenever they differ from -2, and is finally clamped to [0, MAXCHANNELS]
 * (a warning is logged when the upper clamp applies).
 */
int scclMinNchannels() {
    int lowerBound = 2;

    const auto legacyValue = scclParamMinNrings();
    if(legacyValue != -2)
        lowerBound = legacyValue;

    const auto currentValue = scclParamMinNchannels();
    if(currentValue != -2)
        lowerBound = currentValue;

    if(lowerBound > MAXCHANNELS) {
        WARN("User asked for a minimum of %d channels, limiting to %d", lowerBound, MAXCHANNELS);
        lowerBound = MAXCHANNELS;
    }
    if(lowerBound < 0) {
        lowerBound = 0;
    }
    return lowerBound;
}
/**
 * Upper bound on the number of channels.
 *
 * Starts at MAXCHANNELS, is overridden by the MaxNrings parameter and then by the
 * MaxNchannels parameter whenever they differ from -2, and is finally clamped to
 * [1, MAXCHANNELS] (a warning is logged when the lower clamp applies).
 */
int scclMaxNchannels() {
    int upperBound = MAXCHANNELS;

    const auto legacyValue = scclParamMaxNrings();
    if(legacyValue != -2)
        upperBound = legacyValue;

    const auto currentValue = scclParamMaxNchannels();
    if(currentValue != -2)
        upperBound = currentValue;

    if(upperBound > MAXCHANNELS) {
        upperBound = MAXCHANNELS;
    }
    if(upperBound < 1) {
        WARN("User asked for a maximum of %d channels, setting it to 1", upperBound);
        upperBound = 1;
    }
    return upperBound;
}

/**
 * Extend the channel set from `start` to `end` channels by replication.
 *
 * Channel c (start <= c < end) becomes a copy of channel c - start: its rows in
 * ringPrev/ringNext and its scclChannel structure are copied. Copies are made in
 * increasing order, so already-created channels may serve as sources.
 *
 * @return the new channel count: `end`, or `start` when end <= start
 */
static int copyChannels(struct scclComm* comm, int start, int end, int* ringPrev, int* ringNext) {
    const int nranks    = comm->nRanks;
    const auto rowBytes = nranks * sizeof(int);

    int dst = start;
    while(dst < end) {
        const int src = dst - start;
        memcpy(ringPrev + dst * nranks, ringPrev + src * nranks, rowBytes);
        memcpy(ringNext + dst * nranks, ringNext + src * nranks, rowBytes);
        memcpy(comm->channels + dst, comm->channels + src, sizeof(struct scclChannel));
        ++dst;
    }
    return dst;
}
static int copyMixedChannels(struct scclComm* comm, int start, int end, int* ringPrev, int* ringNext) {
    int nranks = comm->nRanks;
    int c;
    for(c = start; c < end; c++) {
        memcpy(ringPrev + c * nranks, ringPrev + (c - start) * nranks, nranks * sizeof(int));
        memcpy(ringNext + c * nranks, ringNext + (c - start) * nranks, nranks * sizeof(int));
        memcpy(comm->channels + c, comm->channels + c - start, sizeof(struct scclChannel));
        comm->channels[c].transportType = comm->mixedTransportType;
    }
    return c;
}

// Tunables for the extra "mixed" channels added by scclTopoPostset:
//  - MaxMixedHylinkNChannels caps how many extra channels are created (capped again by the
//    room left below MAXCHANNELS);
//  - MixedTransportType is the transport type given to those channels; scclTopoPostset raises
//    it to at least TRANSPORT_SHM.
// NOTE(review): these use the RCCL_PARAM macro (rcclParam* accessors) while the other
// parameters in this file use SCCL_PARAM -- confirm the differing prefix is intended.
RCCL_PARAM(MaxMixedHylinkNChannels, "MAX_MIXED_HYLINK_NCHANNELS", 0);
RCCL_PARAM(MixedTransportType, "MIXED_TRANSPORT_TYPE", TRANSPORT_SHM);

/**
 * Combine the per-rank presets of all ranks into the final channel set of this communicator.
 *
 * Steps, in order:
 *  1. gather ring/tree/NVLS information of all ranks into flat [channel][node] and
 *     [channel][rank] tables;
 *  2. connect rings, trees and NVLS across nodes (connectRings/connectTrees/connectNvls);
 *  3. account for the channel duplication (doubling when it fits in MAXCHANNELS);
 *  4. optionally add CollNet channels, extra compute channels and "mixed" transport channels;
 *  5. clamp the channel count to the configured maximum, then replicate channels up to the
 *     configured minimum;
 *  6. build and check the final rings with scclBuildRings.
 *
 * @param comm         communicator; comm->nChannels and comm->channels are updated
 * @param firstRanks   per node, the rank whose scclTopoRanks supplies that node's values
 * @param treePatterns forwarded to connectTrees
 * @param allTopoRanks per rank, the values produced by scclTopoPreset on that rank
 * @param rings        output, filled by scclBuildRings
 * @param graphs       graphs indexed by SCCL_ALGO_*
 * @param nc           multiplier applied to comm->nChannels to obtain a channel-count target
 * @return scclSuccess, or the first error reported through SCCLCHECK
 *
 * NOTE(review): the eight tables allocated below are only freed at the end of the function; if
 * SCCLCHECK returns early on failure (presumably it does -- confirm), they are not released.
 */
scclResult_t scclTopoPostset(
    struct scclComm* comm, int* firstRanks, int* treePatterns, struct scclTopoRanks** allTopoRanks, int* rings, struct scclTopoGraph** graphs, int nc) {
    // Gather data from all ranks
    int *ringRecv, *ringSend, *ringPrev, *ringNext, *treeToParent, *treeToChild0, *treeToChild1, *nvlsHeads;
    int nranks       = comm->nRanks;
    int nNodes       = comm->nNodes;
    int nChannels    = comm->nChannels;
    int MinNChannels = scclMinNchannels();
    int MaxNChannels = scclMaxNchannels();
    // All tables are sized for MAXCHANNELS so that later channel replication fits.
    SCCLCHECK(scclCalloc(&ringRecv, nNodes * MAXCHANNELS));
    SCCLCHECK(scclCalloc(&ringSend, nNodes * MAXCHANNELS));
    SCCLCHECK(scclCalloc(&ringPrev, nranks * MAXCHANNELS));
    SCCLCHECK(scclCalloc(&ringNext, nranks * MAXCHANNELS));
    SCCLCHECK(scclCalloc(&treeToParent, nNodes * MAXCHANNELS));
    SCCLCHECK(scclCalloc(&treeToChild0, nNodes * MAXCHANNELS));
    SCCLCHECK(scclCalloc(&treeToChild1, nNodes * MAXCHANNELS));
    SCCLCHECK(scclCalloc(&nvlsHeads, nNodes * MAXCHANNELS));
    for(int c = 0; c < nChannels; c++) {
        // Per-node values are taken from one representative rank of each node.
        for(int n = 0; n < nNodes; n++) {
            int r                        = firstRanks[n];
            ringRecv[c * nNodes + n]     = allTopoRanks[r]->ringRecv[c];
            ringSend[c * nNodes + n]     = allTopoRanks[r]->ringSend[c];
            treeToParent[c * nNodes + n] = allTopoRanks[r]->treeToParent[c];
            treeToChild0[c * nNodes + n] = allTopoRanks[r]->treeToChild0[c];
            treeToChild1[c * nNodes + n] = allTopoRanks[r]->treeToChild1[c];
            nvlsHeads[c * nNodes + n]    = allTopoRanks[r]->nvlsHeads[c];
        }
        // Ring neighbours are per rank.
        for(int r = 0; r < nranks; r++) {
            ringPrev[c * nranks + r] = allTopoRanks[r]->ringPrev[c];
            ringNext[c * nranks + r] = allTopoRanks[r]->ringNext[c];
        }
    }

    // Connect rings and trees. This should also duplicate the channels.
    SCCLCHECK(connectRings(comm, ringRecv, ringSend, ringPrev, ringNext));
    SCCLCHECK(connectTrees(comm, treeToParent, treeToChild0, treeToChild1, treePatterns));
    SCCLCHECK(connectNvls(comm, nvlsHeads, graphs[SCCL_ALGO_NVLS]));

    // Duplicate ringPrev/ringNext for scclBuildRing
    if(nChannels <= MAXCHANNELS / 2)
        memcpy(ringPrev + nChannels * nranks, ringPrev, nChannels * nranks * sizeof(int));
    if(nChannels <= MAXCHANNELS / 2)
        memcpy(ringNext + nChannels * nranks, ringNext, nChannels * nranks * sizeof(int));

    // When scclTopoPathAllNVLink reports 1 and the user did not set the corresponding
    // environment variable, both bounds are forced to 32.
    if(scclTopoPathAllNVLink(comm->topo) == 1 && getenv("SCCL_MIN_NCHANNELS") == NULL)
        MinNChannels = 32;
    if(scclTopoPathAllNVLink(comm->topo) == 1 && getenv("SCCL_MAX_NCHANNELS") == NULL)
        MaxNChannels = 32;

    // NOTE(review): ncSdma is declared only when HCU_SDMA_FEATURE is defined, but it is used
    // unconditionally in the checkSdmaCopyEnable() branch further down -- the file does not
    // compile without that macro. Confirm the macro is always defined or guard the usage.
#ifdef HCU_SDMA_FEATURE
    int ncSdma = nc;
    ncSdma     = std::min((int)scclMaxNchannels() / comm->nChannels, nc);
    ncSdma *= comm->nChannels;
#endif

    // Get number of channels after duplication
    nc = std::min((int)MaxNChannels / comm->nChannels, nc);
    nc *= comm->nChannels;

    // Duplication should be complete now
    nChannels = comm->nChannels = std::min(MAXCHANNELS, (nChannels <= MAXCHANNELS / 2) ? nChannels * 2 : nChannels);

    // Setup CollNet
    if(comm->collNetSupport == 1) {
        struct scclTopoGraph* collNetGraph = graphs[SCCL_ALGO_COLLNET_DIRECT];
        // Add more channels to saturate intra-node bandwidth, except the 1 PPN case
        if(collNetGraph->bwIntra > collNetGraph->bwInter && comm->nRanks > comm->nNodes) {
            int collNetNchannels = std::min(MAXCHANNELS, nChannels + nChannels / 2);
            nChannels = comm->nChannels = copyChannels(comm, nChannels, collNetNchannels, ringPrev, ringNext);
        }
        SCCLCHECK(connectCollNet(comm, collNetGraph));
    }

    // Use 4 compute channels per search channel to reach peak BW on <8 PPN
    if(comm->minCompCap == 90 && comm->nNodes > 1 && graphs[SCCL_ALGO_RING]->bwIntra > 45.0 && 2 * nChannels <= MAXCHANNELS) {
        nChannels = comm->nChannels = copyChannels(comm, nChannels, 2 * nChannels, ringPrev, ringNext);
    }

    // Add Hylink + PCIE double channel path
    // The extra channels are copies of the first ones with transportType overridden
    // (see copyMixedChannels); their number is limited by the room left below MAXCHANNELS.
    if(graphs[SCCL_ALGO_RING]->typeIntra == PATH_NVL) {
        comm->nMixedHylinkChannels = std::min(MAXCHANNELS - comm->nChannels, (int)rcclParamMaxMixedHylinkNChannels());
        if(comm->nMixedHylinkChannels > 0) {
            INFO(SCCL_LOG_TOPO,
                 "<%s:%d> -----> comm->nMixedHylinkShmChannels: %d, comm->nChannels: %d\n",
                 __func__,
                 __LINE__,
                 comm->nMixedHylinkChannels,
                 comm->nChannels);
            comm->mixedTransportType = std::max((int)rcclParamMixedTransportType(), TRANSPORT_SHM);
            nChannels = comm->nChannels = copyMixedChannels(comm, nChannels, nChannels + comm->nMixedHylinkChannels, ringPrev, ringNext);
        }
    }

    // Honor SCCL_MIN_NRINGS/SCCL_MAX_NRINGS.
    // We permit combining max, then min, to only use the first channels, then duplicate them.
    // In every branch below the channel count is first clamped down (maximum, config.maxCTAs and,
    // for a communicator that does not own sharedRes, tpNChannels), then channels are replicated
    // up to the minimum target with copyChannels.
    if(checkSdmaCopyEnable(comm)) {
        uint32_t sdmaChannelNum;
        uint32_t maxChannels;
        sdmaChannelNum = getSdmaChannelNum(comm);
        // A non-zero sdmaChannelNum replaces the computed minimum target.
        if(comm->sharedRes->owner != comm) {
            /* child comm #channels cannot exceed top parent #channels. */
            nChannels = comm->nChannels = std::min(std::min(std::min(scclMaxNchannels(), nChannels), comm->config.maxCTAs), comm->sharedRes->tpNChannels);
            maxChannels =
                sdmaChannelNum ? sdmaChannelNum : std::min(std::max(scclMinNchannels(), std::max(ncSdma, comm->config.minCTAs)), comm->sharedRes->tpNChannels);
            nChannels = comm->nChannels = copyChannels(comm, nChannels, maxChannels, ringPrev, ringNext);
        } else {
            nChannels = comm->nChannels = std::min(std::min(scclMaxNchannels(), nChannels), comm->config.maxCTAs);
            maxChannels                 = sdmaChannelNum ? sdmaChannelNum : std::max(scclMinNchannels(), std::max(ncSdma, comm->config.minCTAs));
            nChannels = comm->nChannels = copyChannels(comm, nChannels, maxChannels, ringPrev, ringNext);
        }
        // NOTE(review): sdmaChannelNum is a uint32_t printed with %d.
        INFO(SCCL_INIT, "-hcugon- scclTopoPostset rank %d sdmaChannelNum %d nChannels %d", comm->rank, sdmaChannelNum, comm->nChannels);
    } else {
        if(comm->sharedRes->owner != comm) {
            /* child comm #channels cannot exceed top parent #channels. */
            nChannels = comm->nChannels = std::min(std::min(std::min(MaxNChannels, nChannels), comm->config.maxCTAs), comm->sharedRes->tpNChannels);
            nChannels = comm->nChannels = copyChannels(
                comm, nChannels, std::min(std::max(MinNChannels, std::max(nc, comm->config.minCTAs)), comm->sharedRes->tpNChannels), ringPrev, ringNext);
        } else {
            nChannels = comm->nChannels = std::min(std::min(MaxNChannels, nChannels), comm->config.maxCTAs);
            nChannels = comm->nChannels = copyChannels(comm, nChannels, std::max(MinNChannels, std::max(nc, comm->config.minCTAs)), ringPrev, ringNext);
        }
    }
    // Create rings array and check all is fine
    SCCLCHECK(scclBuildRings(nChannels, rings, comm->rank, comm->nRanks, ringPrev, ringNext));

    free(ringRecv);
    free(ringSend);
    free(ringPrev);
    free(ringNext);
    free(treeToParent);
    free(treeToChild0);
    free(treeToChild1);
    free(nvlsHeads);

    return scclSuccess;
}

} // namespace detect
} // namespace topology
} // namespace hardware
} // namespace sccl
