#include <unordered_set>
#include "paths.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace graph {

PathFinder::PathFinder(const BootstrapComm_t* bootstrap_comm)
    : rank(bootstrap_comm->rank),
      nRanks(bootstrap_comm->nRanks),
      localRank(bootstrap_comm->localRank),
      nLocalRanks(bootstrap_comm->nLocalRanks),
      interRank(bootstrap_comm->interRank),
      nInterRanks(bootstrap_comm->nInterRanks),
      node_container_(bootstrap_comm->rank_phys_set->node_info_vec.data(),
                      bootstrap_comm->nRanks * bootstrap_comm->rank_phys_set->node_info_total_bytes) { // 初始化NodeContainer对象
    printf("get PathFinder, node_container_=%zu\n", node_container_.size());
    for(size_t i = 0; i < node_container_.size(); ++i) {
        scclTopoNode_t* node = node_container_[i];
        // 检查node是否有效
        if(node->type > CPU) {
            // 将有效的node添加到graph_nodes中，并保存其neighbor的id
            graph_node_neighbors_[node->id] = std::vector<uint64_t>(node->neighbors.begin(), node->neighbors.begin() + node->neighborCount);
            // 构建id到index的映射
            id_to_index_[node->id] = i;
        }
    }

#if 0
    if(rank == 0) {
        // 遍历id_to_index_进行打印
        for(const auto& pair : id_to_index_) {
            uint64_t node_id           = pair.first;
            size_t index               = pair.second;
            const scclTopoNode_t* node = node_container_[index];

            int interRank, deviceValue, terminalType, hipDev, numaId;
            bootstrap::physical_links::getIdComponents(node_id, &interRank, &deviceValue, &terminalType, &hipDev, &numaId);
            char busIdStr[17];
            int64ToBusId(node->busId, busIdStr);
            printf("rank=%d, node=(InterRank:%d, V:%d, T:%d, H:%d, N:%d, type:%d, busIdStr:%s), neighbor_count=%zu",
                   rank,
                   interRank,
                   deviceValue,
                   terminalType,
                   hipDev,
                   numaId,
                   node->type,
                   busIdStr,
                   node->neighborCount);

            for(int n = 0; n < node->neighborCount; ++n) {
                uint64_t neighbor_id                = node->neighbors[n];
                const scclTopoNode_t* neighbor_node = findNodeById(neighbor_id);
                if(neighbor_node) {
                    bootstrap::physical_links::getIdComponents(neighbor_id, &interRank, &deviceValue, &terminalType, &hipDev, &numaId);
                    int64ToBusId(neighbor_node->busId, busIdStr);

                    printf(", neighbor[%d]=(InterRank:%d, V:%d, T:%d, H:%d, N:%d, type:%d, busIdStr:%s)",
                           n,
                           interRank,
                           deviceValue,
                           terminalType,
                           hipDev,
                           numaId,
                           neighbor_node->type,
                           busIdStr);
                } else {
                    printf(", neighbor[%d]=unknown", n);
                }
            }
            printf("\n");
        }
    }
#endif

    // 查找当前rank对应的其他GPU节点的所有路径
    printf("PathFinder pos 1\n");
    findGpuPaths();
    printf("PathFinder pos 2\n");
}

/**
 * @brief 计算拓扑图中GPU节点之间的点对点映射
 *
 * 该函数用于计算拓扑图中GPU节点之间的点对点映射。它遍历`gpu_paths_`中的所有路径，
 * 对于每一条路径，它将路径中的每个节点添加到`topo_graph`的`graph_nodes`中。然后，它根据路径中途径的节点点确定连接方式的类型，
 * 并将连接方式的类型存储在`topo_graph`的`transport_map`中。最后，它将路径添加到`topo_graph`的`gpu_paths`中。
 *
 * @param topo_graph 指向拓扑图的指针
 * @return scclResult_t 计算结果
 */
scclResult_t PathFinder::computeTopoGpuP2pMap(scclTopoGraph_t* topo_graph) {
    // 遍历gpu_paths_中的所有路径
    for(const auto& start_node_pair : gpu_paths_) {
        uint64_t start_node_id = start_node_pair.first;
        const auto& paths      = start_node_pair.second;

        // 遍历从start_node_id到其他GPU节点的所有路径
        for(const auto& path : paths) {
#if 0
            printf("paths len=%zu, path=%zu\n", paths.size(), path.size());
#endif
            if(path.size() == 0)
                continue;

            // 遍历路径中的每个节点，将其添加到graph_nodes中
            for(uint64_t node_id : path) {
                // 查找node_container_中对应的节点
                const scclTopoNode_t* node = findNodeById(node_id);
                if(node != nullptr) {
                    // 检查节点是否已经存在于graph_nodes中
                    auto it = topo_graph->graph_nodes.find(node_id);
                    if(it == topo_graph->graph_nodes.end()) {
                        // 如果节点不存在于graph_nodes中，则将其拷贝到graph_nodes
                        topo_graph->graph_nodes[node_id] = *node;
                    }
                }
            }

            // 将路径添加到topo_graph中的gpu_paths
            uint64_t end_node_id = path.back(); // 获取路径的最后一个节点的ID

            // 记录bitmap
            LinkType_t link_type;
            int start_gpu_rank, end_gpu_rank;
            {
                // 根据路径中途径的节点点确定连接方式的类型
                SCCLCHECK(determineLinkType(path, &link_type));

                int start_interRank, start_hipDev;
                int end_interRank, end_hipDev;
                bootstrap::physical_links::getIdComponents(start_node_id, &start_interRank, nullptr, nullptr, &start_hipDev, nullptr);
                bootstrap::physical_links::getIdComponents(end_node_id, &end_interRank, nullptr, nullptr, &end_hipDev, nullptr);

                start_gpu_rank = start_interRank * nLocalRanks + start_hipDev;
                end_gpu_rank   = end_interRank * nLocalRanks + end_hipDev;
#if 0
                printf("rank=%d, interRank=%d, localRank=%d: start_interRank=%d, start_hipDev=%d, end_interRank=%d, end_hipDev=%d, link_type=%d\n",
                       rank,
                       interRank,
                       localRank,
                       start_interRank,
                       start_hipDev,
                       end_interRank,
                       end_hipDev,
                       static_cast<int>(link_type));
#endif
            }

            // 将连接方式的类型存储在transport_map中
            if(*(topo_graph->getTransportMapData(start_gpu_rank, end_gpu_rank)) > 0 && link_type > 0) {
                if(link_type < static_cast<LinkType_t>(*(topo_graph->getTransportMapData(start_gpu_rank, end_gpu_rank)))) {
                    *(topo_graph->getTransportMapData(start_gpu_rank, end_gpu_rank)) = link_type;
                    // 清空之前的路径
                    topo_graph->gpu_paths[start_node_id][end_node_id].clear();
                    // 添加新的路径
                    topo_graph->gpu_paths[start_node_id][end_node_id].push_back(path);
                } else if(link_type == static_cast<LinkType_t>(*(topo_graph->getTransportMapData(start_gpu_rank, end_gpu_rank)))) {
                    // 添加新的路径
                    topo_graph->gpu_paths[start_node_id][end_node_id].push_back(path);
                }
            } else {
                *(topo_graph->getTransportMapData(start_gpu_rank, end_gpu_rank)) = static_cast<uint8_t>(link_type);
                // 添加新的路径
                topo_graph->gpu_paths[start_node_id][end_node_id].push_back(path);
            }
        }
    }

    return scclSuccess;
}

/////////////////////////////////////////////////////////////////////////////////////////////
/**
 * @brief 查找当前rank对应的其他GPU节点的所有路径
 *
 * 该函数用于查找当前rank对应的GPU节点的所有路径。它遍历`id_to_index_`中的所有节点ID和索引对，
 * 对于每一个节点，如果该节点是GPU类型，并且属于当前rank的进程，则调用`bfsFindGpuPaths`函数执行广度优先搜索（BFS），
 * 查找到其他所有GPU节点的路径。最后，如果当前rank为1，则调用`printGpuPaths`函数打印所有GPU路径。
 */
void PathFinder::findGpuPaths() {
    // 查找当前rank对应的GPU的node，并执行BFS搜索，查找到其他所有GPU node的路径
    for(const auto& pair : id_to_index_) {
        uint64_t id  = pair.first;
        size_t index = pair.second;

        // 定位到node
        scclTopoNode_t* node = node_container_[index];
        int nodeInterRank, nodeHipDev;
        bootstrap::physical_links::getIdComponents(node->id, &nodeInterRank, nullptr, nullptr, &nodeHipDev, nullptr);
        if(node->type == GPU && nodeInterRank == this->interRank && nodeHipDev == this->localRank) {
            // printf("bfsFindGpuPaths start_node_id=%lu, running\n", node->id);
            bfsFindGpuPaths(node->id);
        }
    }
#if 1
    if(rank == 1) {
        printGpuPaths();
    }
#endif
}

/**
 * @brief 根据节点ID查找节点
 *
 * 该函数接收一个节点ID，并在`node_container_`中查找具有该ID的节点。如果找到了具有指定ID的节点，则返回指向该节点的指针；否则返回`nullptr`。
 *
 * @param id 要查找的节点ID
 * @return 如果找到了具有指定ID的节点，则返回指向该节点的指针；否则返回`nullptr`
 */
const scclTopoNode_t* PathFinder::findNodeById(uint64_t id) const {
    // 使用id_to_index_映射查找具有指定id的节点的索引
    auto it = id_to_index_.find(id);
    if(it != id_to_index_.end()) {
        return node_container_[it->second];
    }
    return nullptr; // 如果未找到具有指定id的节点，则返回nullptr
}

// TODO: 当nRanks特别大时，可以考虑采用kernel实现
/**
 * @brief 使用广度优先搜索（BFS）查找从起始GPU节点到其他GPU节点的所有路径
 *
 * 1.该函数从指定的起始GPU节点开始，使用广度优先搜索算法查找所有能够到达的GPU节点，并记录从起始节点到每个目标GPU节点的所有路径。
 * 每条路径中的所有节点最多使用一次。
 *
 * 2.该函数还添加了一个限制，以防止在路径中出现`interRank`在变化后又变回来的情况。
 * 也就是说，如果路径从`interRank == 0`连接到`interRank == 1`的节点后，则不能再连接回`interRank == 0`。
 *
 * @param start_node_id 起始GPU节点的ID
 */
#if 1
void PathFinder::bfsFindGpuPaths(uint64_t start_node_id) {
    // 使用一个队列来存储当前路径
    std::queue<std::vector<uint64_t>> queue;
    // 使用一个unordered_map来存储每个node的所有最短路径
    std::unordered_map<uint64_t, std::vector<std::vector<uint64_t>>> shortest_paths;

    // 将起始节点加入队列
    queue.push({start_node_id});
    shortest_paths[start_node_id] = {{start_node_id}};

    // 当队列不为空时，继续搜索
    while(!queue.empty()) {
        // 从队列中取出一个路径
        auto path = queue.front();
        queue.pop();

        // 获取当前路径的最后一个节点的ID
        uint64_t nodeId = path.back();
        // 根据节点ID查找对应的节点
        const scclTopoNode_t* current_node = findNodeById(nodeId);
        if(current_node == nullptr) {
            continue;
        }

        // 如果当前节点是GPU节点且不是起始节点，则将当前路径加入结果
        if(current_node->type == GPU && nodeId != start_node_id) {
            int hipDev;
            bootstrap::physical_links::getIdComponents(current_node->id, nullptr, nullptr, nullptr, &hipDev, nullptr);
            // 仅当节点内的device id小于等于nLocalRanks时，才是有效GPU，才将路径加入结果
            if(hipDev < nLocalRanks) {
                gpu_paths_[start_node_id].push_back(path);
            }
        } else {
            int nodeInterRank;
            bootstrap::physical_links::getIdComponents(nodeId, &nodeInterRank);
            // 遍历当前节点的所有邻居节点
            for(uint64_t neighbor_id : graph_node_neighbors_.at(nodeId)) {
                if(findNodeById(neighbor_id) == nullptr) {
                    continue;
                }
                // 获取邻居节点的interRank
                int neighbor_inter_rank;
                bootstrap::physical_links::getIdComponents(neighbor_id, &neighbor_inter_rank);

                // 检查邻居节点是否已在当前路径中访问过
                bool visited = std::find(path.begin(), path.end(), neighbor_id) != path.end();
                // 检查interRank是否已经存在（仅当interRank改变时）
                bool inter_rank_exists = neighbor_inter_rank != nodeInterRank && std::find(path.begin(), path.end(), neighbor_id) != path.end();
                // 如果邻居节点未访问过且interRank未存在，则扩展路径
                if(!visited && !inter_rank_exists) {
                    std::vector<uint64_t> new_path = path;
                    new_path.push_back(neighbor_id);

                    // 如果新路径比已有的最短路径更短，或者长度相同但尚未记录，则更新最短路径
                    auto& paths = shortest_paths[neighbor_id];
                    if(paths.empty() || paths.front().size() > new_path.size() ||
                       (paths.front().size() == new_path.size() && std::find(paths.begin(), paths.end(), new_path) == paths.end())) {
                        if(paths.empty() || paths.front().size() > new_path.size()) {
                            paths = {new_path};
                        } else {
                            paths.push_back(new_path);
                        }
                        queue.push(new_path);
                    }
                }
            }
        }
    }
}

#else

void PathFinder::bfsFindGpuPaths(uint64_t start_node_id) {
    // 使用一个队列来存储当前路径
    std::queue<std::vector<uint64_t>> queue;
    // 使用一个unordered_map来存储每个node的最短路径
    std::unordered_map<uint64_t, std::vector<uint64_t>> shortest_paths;

    // 将起始节点加入队列
    queue.push({start_node_id});
    shortest_paths[start_node_id] = {start_node_id};
    // 当队列不为空时，继续搜索
    while(!queue.empty()) {
        // 从队列中取出一个路径
        auto path = queue.front();
        queue.pop();
        // 获取当前路径的最后一个节点的ID
        uint64_t nodeId = path.back();
        // 根据节点ID查找对应的节点
        const scclTopoNode_t* current_node = findNodeById(nodeId);
        if(current_node == nullptr) {
            continue;
        }
        // 如果当前节点是GPU节点且不是起始节点，则将当前路径加入结果
        if(current_node->type == GPU && nodeId != start_node_id) {
            int hipDev;
            bootstrap::physical_links::getIdComponents(current_node->id, nullptr, nullptr, nullptr, &hipDev, nullptr);
            if(hipDev < nLocalRanks) {
                gpu_paths_[start_node_id].push_back(path);
            }
        } else {
            int nodeInterRank;
            bootstrap::physical_links::getIdComponents(nodeId, &nodeInterRank);
            // 遍历当前节点的所有邻居节点
            for(uint64_t neighbor_id : graph_node_neighbors_.at(nodeId)) {
                if(findNodeById(neighbor_id) == nullptr) {
                    continue;
                }
                // 获取邻居节点的interRank
                int neighbor_inter_rank;
                bootstrap::physical_links::getIdComponents(neighbor_id, &neighbor_inter_rank);
                // 检查邻居节点是否已在当前路径中访问过
                bool visited = std::find(path.begin(), path.end(), neighbor_id) != path.end();
                // 检查interRank是否已经存在（仅当interRank改变时）
                bool inter_rank_exists = false;
                if(neighbor_inter_rank != nodeInterRank) {
                    for(uint64_t node_id : path) {
                        if(node_id == neighbor_id) {
                            inter_rank_exists = true;
                            break;
                        }
                    }
                }
                // 如果邻居节点未访问过且interRank未存在，则扩展路径
                if(!visited && !inter_rank_exists) {
                    std::vector<uint64_t> new_path = path;
                    new_path.push_back(neighbor_id);
                    // 如果新路径比已有的最短路径更短，则更新最短路径
                    if(shortest_paths.find(neighbor_id) == shortest_paths.end() || shortest_paths[neighbor_id].size() > new_path.size()) {
                        shortest_paths[neighbor_id] = new_path;
                        queue.push(new_path);
                    }
                }
            }
        }
    }
}

void PathFinder::bfsFindGpuPaths(uint64_t start_node_id) {
    // 使用一个队列来存储当前路径
    std::queue<std::vector<uint64_t>> queue;
    // 将起始节点加入队列
    queue.push({start_node_id});

    // 当队列不为空时，继续搜索
    while(!queue.empty()) {
        // 从队列中取出一个路径
        auto path = queue.front();
        queue.pop();

        // 获取当前路径的最后一个节点的ID
        uint64_t nodeId = path.back();
        // 根据节点ID查找对应的节点
        const scclTopoNode_t* current_node = findNodeById(nodeId);
        if(current_node == nullptr) {
            continue;
        }

        // 如果当前节点是GPU节点且不是起始节点，则将当前路径加入结果
        if(current_node->type == GPU && nodeId != start_node_id) {
            int hipDev;
            bootstrap::physical_links::getIdComponents(current_node->id, nullptr, nullptr, nullptr, &hipDev, nullptr);
            if(hipDev < nLocalRanks) {
                gpu_paths_[start_node_id].push_back(path);
            }
        } else {
            int nodeInterRank;
            bootstrap::physical_links::getIdComponents(nodeId, &nodeInterRank);

            // 遍历当前节点的所有邻居节点
            for(uint64_t neighbor_id : graph_node_neighbors_.at(nodeId)) {
                if(findNodeById(nodeId) == nullptr) {
                    continue;
                }

                // 获取邻居节点的interRank
                int neighbor_inter_rank;
                bootstrap::physical_links::getIdComponents(neighbor_id, &neighbor_inter_rank);

                // 检查邻居节点是否已在当前路径中访问过
                bool visited = std::find(path.begin(), path.end(), neighbor_id) != path.end();

                // 检查interRank是否已经存在（仅当interRank改变时）
                bool inter_rank_exists = false;
                if(neighbor_inter_rank != (nodeInterRank)) {
                    for(uint64_t node_id : path) {
                        if((nodeInterRank) == neighbor_inter_rank) {
                            inter_rank_exists = true;
                            break;
                        }
                    }
                }

                // 如果邻居节点未访问过且interRank未存在，则扩展路径
                if(!visited && !inter_rank_exists) {
                    std::vector<uint64_t> new_path = path;
                    new_path.push_back(neighbor_id);
                    queue.push(new_path);
                }
            }
        }
    }
}
#endif

/**
 * @brief 打印GPU路径信息
 *
 * 该函数用于打印`gpu_paths_`中存储的所有GPU路径信息。对于每一条路径，
 * 它会打印路径的长度以及路径中每个节点的详细信息，包括节点的`interRank`、
 * `deviceValue`、`terminalType`和`numaId`。
 */
void PathFinder::printGpuPaths() {
    // 遍历gpu_paths_中的每一对(start_node_id, paths)
    for(const auto& start_node_pair : gpu_paths_) {
        uint64_t start_node_id = start_node_pair.first; // 获取起始节点的ID

        char busIdStr[17] = ""; // 用于存储总线ID字符串
        // 根据起始节点的ID查找对应的节点对象
        const scclTopoNode_t* start_node = findNodeById(start_node_id);
        // 如果找到了对应的节点对象，则将其总线ID转换为字符串
        if(start_node) {
            int64ToBusId(start_node->busId, busIdStr);
        } else {
            return;
        }
        const auto& paths = start_node_pair.second; // 获取与起始节点关联的所有路径
        size_t path_count = paths.size();           // 计算路径的数量

        int interRank, deviceValue, terminalType, hipDev, numaId;
        // 根据起始节点的ID获取其interRank、deviceValue、terminalType和numaId
        bootstrap::physical_links::getIdComponents(start_node_id, &interRank, &deviceValue, &terminalType, &hipDev, &numaId);
        printf("GPU node ID:%lu (InterRank:%d, V:%d, T:%d, H:%d, N:%d) (Path count: %zu)\n",
               start_node_id,
               interRank,
               deviceValue,
               terminalType,
               hipDev,
               numaId,
               path_count);

        // 遍历与起始节点关联的所有路径
        for(const auto& path : paths) {
            size_t path_length = path.size(); // 计算路径的长度
            // 打印路径的长度
            printf("Path (length: %zu): ", path_length);

            // 遍历路径中的每个节点
            for(size_t i = 0; i < path.size(); ++i) {
                uint64_t node_id = path[i]; // 获取节点的ID
                // 使用findNodeById函数查找节点的详细信息
                const scclTopoNode_t* node = findNodeById(node_id);
                if(node) {
                    // 根据节点的ID获取其interRank、deviceValue、terminalType和numaId
                    bootstrap::physical_links::getIdComponents(node->id, &interRank, &deviceValue, &terminalType, &hipDev, &numaId);
                    // 将节点的总线ID转换为字符串
                    int64ToBusId(node->busId, busIdStr);
                    // 打印节点的信息，包括其interRank、deviceValue、terminalType、numaId、类型和总线ID字符串
                    printf("ID:%lu (InterRank:%d, V:%d, T:%d, H:%d, N:%d, type:%d, busIdStr:%s)",
                           node->id,
                           interRank,
                           deviceValue,
                           terminalType,
                           hipDev,
                           numaId,
                           node->type,
                           busIdStr);
                }
                // 如果当前节点不是路径中的最后一个节点，则打印" -> "以分隔节点
                if(i != path.size() - 1) {
                    printf(" -> ");
                }
            }
            // 换行，准备打印下一条路径
            printf("\n=============================================\n");
        }
    }
}

scclResult_t PathFinder::determineLinkType(const std::vector<uint64_t>& path, LinkType_t* link_type) {
    if(path.size() == 1) {
        *link_type = LINK_LOC;
    }

    bool has_gpu = false, has_pix = false, has_nic = false, has_cpu = false;
    // 遍历路径中的每个节点，从第2个点开始
    for(int i = 1; i < path.size(); i++) {
        uint64_t node_id = path[i];

        // 查找node_container_中对应的节点
        const scclTopoNode_t* node = findNodeById(node_id);
        if(node == nullptr) {
            WARN("cannot find node from id: %lu", node_id);
            return scclInternalError;
        }

        // 根据节点的类型确定连接方式的类型
        switch(node->type) {
            case GPU: has_gpu = true; break;
            case PCI: has_pix = true; break;
            case NIC: has_nic = true; break;
            case CPU: has_cpu = true; break;
            default: break;
        }
    }

    // 根据路径中节点的类型确定连接方式的类型
    if(has_cpu) {
        *link_type = LINK_NET;
    } else if(has_nic) {
        *link_type = LINK_PXN;
    } else if(has_pix) {
        *link_type = LINK_PIX;
    } else if(has_gpu) {
        *link_type = LINK_NVL;
    } else {
        *link_type = LINK_NONE; // 默认返回0
    }

    return scclSuccess;
}

} // namespace graph
} // namespace topology
} // namespace hardware
} // namespace sccl
