"awq_cuda/pybind_windows.cpp" did not exist on "a11c313a4d451587b97300e99eff1d1bcdd82d7e"
hardware.cpp 4.15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include <stdint.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>

#include "base.h"
#include "hardware_utils.h"
#include "bootstrap.h"

namespace sccl {
namespace hardware {
namespace topology {
namespace bootstrap {

// 全局变量,全部节点的信息
struct BootstrapComm bootstrap_comm;

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
scclResult_t scclGetUniqueId(scclUniqueId* unique_id) {
    auto handle = reinterpret_cast<struct BootstrapHandle*>(unique_id);
    NEQCHECK(sizeof(struct BootstrapHandle), SCCL_UNIQUE_ID_BYTES);
    SCCLCHECK(bootstrapGetUniqueId(handle));
    return scclSuccess;
}

scclResult_t sccl_init(const scclUniqueId* unique_id, int rank, int nRanks) {
    // -------------------------- 1.获取0号rank的地址信息 ----------------------------------- //
    auto root_handle = reinterpret_cast<const struct BootstrapHandle*>(unique_id);
    EQCHECK(root_handle->magic, 0); // 检查handle是否已经更新

    // -------------------------- 2.初始化获取所有节点的node信息 ----------------------------------- //
    auto sccl_bootstrap = std::make_unique<Bootstrap>(root_handle, rank, nRanks);
    SCCLCHECK(sccl_bootstrap->init(&bootstrap_comm));

    // // -------------------------- 3.MPI allgather设置unique_id的整合 ----------------------------------- //

    // auto unique_ids_chr = reinterpret_cast<const char*>(unique_ids);

    // // -------------------------- 3.MPI allgather设置unique_id的整合 ----------------------------------- //
    // std::vector<scclUniqueId> unique_id_vec(nRanks);
    // MPI_Allgather(&unique_id, sizeof(scclUniqueId), MPI_BYTE, &unique_id_vec[0], sizeof(scclUniqueId), MPI_BYTE, MPI_COMM_WORLD);

    // for(int i = 0; i < nRanks; ++i) {
    //     auto root_handle = reinterpret_cast<const struct BootstrapHandle*>(unique_ids_chr + i * sizeof(struct BootstrapHandle));
    //     printf("rank=%d, i=%d, unique_ids hosthash=%lu\n", root_handle->rank, i, root_handle->hostHash);
    // }

    // ByteSpan<struct BootstrapHandle> unique_ids_span(unique_ids_chr, nRanks * sizeof(struct BootstrapHandle));

    // // -------------------------- 2.设置基础信息 ----------------------------------- //
    // INFO(SCCL_LOG_TOPO, "Bootstrap ...\n");
    // struct scclRankInfo rank_info;
    // rank_info.rank   = rank;
    // rank_info.nRanks = nRanks;

    //     // 在每个进程中设置 root_handle 的值
    //     root_handle.rank               = rank_info->rank;
    //     root_handle.hostHash           = getHostHash();
    //     scclSocketAddress_t localSocketAddr = sccl_bootstrap->getLocalSocketAddr();
    //     memcpy(&root_handle.addr, &localSocketAddr, sizeof(scclSocketAddress_t));

    // #if 1
    //     char line[100];
    //     sprintf(line, "pos 55: rank=%d", rank);
    //     SCCLCHECK(hardware::net::printSocketAddr(&root_handle.addr, line));
    //     printf("root_handle.hostHash rank=%d, hash=%lu\n", rank, root_handle.hostHash);
    // #endif

    // // -------------------------- 3.收集所有进程的 root_handle 信息 ----------------------------------- //

    // std::vector<char> recvBuffer(nRanks * sendBuffer.size());

    // SCCLCHECK(mpi::wrap_mpi_allgather(sendBuffer.data(), sendBuffer.size(), MPI_BYTE, recvBuffer.data(), sendBuffer.size(), MPI_BYTE, MPI_COMM_WORLD));

    // -------------------------- 4.设置各个节点的基础信息 ----------------------------------- //
    // SCCLCHECK(sccl_bootstrap->bootstrapInit(rank_info, recvBuffer.data()));

    // -------------------------- 5.根据各个节点的基础信息计算topo结果 ----------------------------------- //

    return scclSuccess;
}

scclResult_t sccl_finalize() {
    // 设置一些全局变量的重置和销毁
    // 设置socket等硬件监听的关闭
    // void BootstrapComm::destroy() {
    if(bootstrap_comm.nRanks > 0) {
        bootstrap_comm.destroy();
    }

    return scclSuccess;
}

} // namespace bootstrap
} // namespace topology
} // namespace hardware
} // namespace sccl