hardware.cpp 2.73 KB
Newer Older
1
2
3
4
5
6
7
#include <stdint.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>

#include "base.h"
#include "hardware_utils.h"
#include "bootstrap.h"
8
9
#include "graph.h"
#include "hardware.h"
10
11
12
13
14

namespace sccl {
namespace hardware {

// 全局变量,全部节点的信息
15
16
17
18
19
typedef sccl::hardware::topology::bootstrap::BootstrapComm_t BootstrapComm_t;
typedef sccl::hardware::topology::graph::scclTopoGraph_t scclTopoGraph_t;

BootstrapComm_t* bootstrap_comm;
scclTopoGraph_t* topo_graph;
20
21
22
23

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
scclResult_t scclGetUniqueId(scclUniqueId* unique_id) {
24
25
26
    auto handle = reinterpret_cast<BootstrapHandle_t*>(unique_id);
    NEQCHECK(sizeof(BootstrapHandle_t), SCCL_UNIQUE_ID_BYTES);
    SCCLCHECK(topology::bootstrap::bootstrapGetUniqueId(handle));
27
28
29
30
31
    return scclSuccess;
}

scclResult_t sccl_init(const scclUniqueId* unique_id, int rank, int nRanks) {
    // -------------------------- 1.获取0号rank的地址信息 ----------------------------------- //
32
    auto root_handle = reinterpret_cast<const BootstrapHandle_t*>(unique_id);
33
34
35
    EQCHECK(root_handle->magic, 0); // 检查handle是否已经更新

    // -------------------------- 2.初始化获取所有节点的node信息 ----------------------------------- //
36
    auto sccl_bootstrap = std::make_unique<topology::bootstrap::Bootstrap>(root_handle, rank, nRanks);
37
38
39
40

    bootstrap_comm = new BootstrapComm_t();
    SCCLCHECK(sccl_bootstrap->init(bootstrap_comm));
    printf("init pos 1\n");
41

42
    // -------------------------- 3.MPI 建图 ----------------------------------- //
43
    topo_graph      = new scclTopoGraph_t(nRanks);
44
45
    auto sccl_graph = std::make_unique<topology::graph::Graph>(rank, nRanks);
    printf("init pos 2\n");
46

47
    // 计算通信路径
48
    SCCLCHECK(sccl_graph->calculateCommunicationPaths(bootstrap_comm, topo_graph, sccl_bootstrap.get()));
49
    printf("init pos 3\n");
50

51
    // -------------------------- 3.MPI allgather设置unique_id的整合 ----------------------------------- //
52
53
54

    // -------------------------- 5.根据各个节点的基础信息计算topo结果 ----------------------------------- //

55
56
57
58
    // // 后续放入到sccl_finalize中
    // delete bootstrap_comm;
    // delete topo_graph;

59
60
61
62
63
64
65
    return scclSuccess;
}

scclResult_t sccl_finalize() {
    // 设置一些全局变量的重置和销毁
    // 设置socket等硬件监听的关闭
    // void BootstrapComm::destroy() {
66
67
68
    // if(bootstrap_comm.nRanks > 0) {
    //     bootstrap_comm.destroy();
    // }
69
70
71
72
73
74

    return scclSuccess;
}

} // namespace hardware
} // namespace sccl