metis_cpu.cpp 3.85 KB
Newer Older
quyuanhao123's avatar
quyuanhao123 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include "metis_cpu.h"

#include <cstdlib>

#ifdef WITH_METIS
#include <metis.h>
#endif

#ifdef WITH_MTMETIS
#include <mtmetis.h>
#endif

#include "utils.h"

// Partitions the CSR graph (rowptr, col) into `num_parts` parts using METIS.
//
// rowptr / col are handed to METIS as xadj / adjncy, so they must be int64 CPU
// tensors (`data_ptr<int64_t>()` enforces the dtype). NOTE(review): this also
// assumes METIS was built with a 64-bit idx_t (IDXTYPEWIDTH=64) — confirm
// against the build configuration.
//
// optional_value:       per-edge weights (adjwgt), one entry per element of col.
// optional_node_weight: per-vertex weights (vwgt), one entry per row.
// num_parts:            number of parts requested from METIS.
// recursive:            use METIS_PartGraphRecursive instead of
//                       METIS_PartGraphKway.
//
// Returns a tensor with rowptr's options and rowptr.numel() - 1 entries, filled
// by METIS with the part id of each vertex. Raises if METIS reports an error
// instead of returning an unfilled buffer.
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                            torch::optional<torch::Tensor> optional_value,
                            torch::optional<torch::Tensor> optional_node_weight,
                            int64_t num_parts, bool recursive) {
#ifdef WITH_METIS
  CHECK_CPU(rowptr);
  CHECK_CPU(col);

  if (optional_value.has_value()) {
    CHECK_CPU(optional_value.value());
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().numel() == col.numel());
  }

  if (optional_node_weight.has_value()) {
    CHECK_CPU(optional_node_weight.value());
    CHECK_INPUT(optional_node_weight.value().dim() == 1);
    CHECK_INPUT(optional_node_weight.value().numel() == rowptr.numel() - 1);
  }

  // METIS reads these buffers as flat arrays, so a strided view would be
  // misread. contiguous() is a no-op for tensors that already are; the local
  // tensors keep the buffers alive until METIS returns.
  rowptr = rowptr.contiguous();
  col = col.contiguous();

  int64_t nvtxs = rowptr.numel() - 1;
  int64_t ncon = 1; // a single balancing constraint
  auto *xadj = rowptr.data_ptr<int64_t>();
  auto *adjncy = col.data_ptr<int64_t>();

  torch::Tensor value;
  int64_t *adjwgt = nullptr;
  if (optional_value.has_value()) {
    value = optional_value.value().contiguous();
    adjwgt = value.data_ptr<int64_t>();
  }

  torch::Tensor node_weight;
  int64_t *vwgt = nullptr;
  if (optional_node_weight.has_value()) {
    node_weight = optional_node_weight.value().contiguous();
    vwgt = node_weight.data_ptr<int64_t>();
  }

  int64_t objval = -1;
  auto part = torch::empty({nvtxs}, rowptr.options());

  // Nothing to partition; skip METIS, which may reject an empty graph.
  if (nvtxs == 0)
    return part;

  auto *part_data = part.data_ptr<int64_t>();

  const int status =
      recursive
          ? METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, vwgt, nullptr,
                                     adjwgt, &num_parts, nullptr, nullptr,
                                     nullptr, &objval, part_data)
          : METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, vwgt, nullptr,
                                adjwgt, &num_parts, nullptr, nullptr, nullptr,
                                &objval, part_data);

  // If METIS fails, `part` may not have been filled; report the failure
  // rather than handing back uninitialized memory.
  TORCH_CHECK(status == METIS_OK,
              "METIS partitioning failed with error code ", status);

  return part;
#else
  AT_ERROR("Not compiled with METIS support");
#endif
}

// Multi-threaded counterpart of partition_cpu(), backed by mt-metis.
//
// Needs mt-metis installed via:
// ./configure --shared --edges64bit --vertices64bit --weights64bit
//             --partitions64bit
// The pointer casts below reinterpret int64 tensor buffers as mt-metis's
// adjacency / vertex / weight / partition types, so they rely on those 64-bit
// widths.
//
// optional_value:       per-edge weights (adjwgt), one entry per element of col.
// optional_node_weight: per-vertex weights (vwgt), one entry per row.
// num_parts:            number of parts requested.
// recursive:            use MTMETIS_PartGraphRecursive instead of
//                       MTMETIS_PartGraphKway.
// num_workers:          stored in opts[MTMETIS_OPTION_NTHREADS].
//
// Returns a tensor with rowptr's options and rowptr.numel() - 1 entries,
// filled by mt-metis with each vertex's part id. Raises if mt-metis reports
// an error.
torch::Tensor
mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                 torch::optional<torch::Tensor> optional_value,
                 torch::optional<torch::Tensor> optional_node_weight,
                 int64_t num_parts, bool recursive, int64_t num_workers) {
#ifdef WITH_MTMETIS
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (optional_value.has_value()) {
    CHECK_CPU(optional_value.value());
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().numel() == col.numel());
  }

  if (optional_node_weight.has_value()) {
    CHECK_CPU(optional_node_weight.value());
    CHECK_INPUT(optional_node_weight.value().dim() == 1);
    CHECK_INPUT(optional_node_weight.value().numel() == rowptr.numel() - 1);
  }

  // mt-metis reads these buffers as flat arrays; make strided views
  // contiguous and keep the resulting tensors alive until the call returns.
  rowptr = rowptr.contiguous();
  col = col.contiguous();

  // Signed vertex count used for tensor sizing: mtmetis_vtx_type may be
  // unsigned, and brace-initializing a size list from it would narrow.
  const int64_t num_nodes = rowptr.numel() - 1;

  mtmetis_vtx_type nvtxs = num_nodes;
  mtmetis_vtx_type ncon = 1; // a single balancing constraint
  auto *xadj =
      reinterpret_cast<mtmetis_adj_type *>(rowptr.data_ptr<int64_t>());
  auto *adjncy = reinterpret_cast<mtmetis_vtx_type *>(col.data_ptr<int64_t>());

  torch::Tensor value;
  mtmetis_wgt_type *adjwgt = nullptr;
  if (optional_value.has_value()) {
    value = optional_value.value().contiguous();
    adjwgt = value.data_ptr<int64_t>();
  }

  torch::Tensor node_weight;
  mtmetis_wgt_type *vwgt = nullptr;
  if (optional_node_weight.has_value()) {
    node_weight = optional_node_weight.value().contiguous();
    vwgt = node_weight.data_ptr<int64_t>();
  }

  mtmetis_pid_type nparts = num_parts;
  mtmetis_wgt_type objval = -1;
  auto part = torch::empty({num_nodes}, rowptr.options());

  // Nothing to partition; skip mt-metis, which may reject an empty graph.
  if (num_nodes == 0)
    return part;

  auto *part_data =
      reinterpret_cast<mtmetis_pid_type *>(part.data_ptr<int64_t>());

  double *opts = mtmetis_init_options();
  TORCH_CHECK(opts != nullptr, "mt-metis failed to allocate its options");
  opts[MTMETIS_OPTION_NTHREADS] = static_cast<double>(num_workers);

  const int status =
      recursive
          ? MTMETIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, vwgt,
                                       nullptr, adjwgt, &nparts, nullptr,
                                       nullptr, opts, &objval, part_data)
          : MTMETIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, vwgt, nullptr,
                                  adjwgt, &nparts, nullptr, nullptr, opts,
                                  &objval, part_data);

  // The options array was previously leaked on every call. NOTE(review):
  // this assumes mt-metis allocates it with malloc (its own command-line
  // tool releases it with free) — confirm against the installed version.
  // Nothing between mtmetis_init_options() and this point can throw.
  std::free(opts);

  TORCH_CHECK(status == MTMETIS_SUCCESS,
              "mt-metis partitioning failed with error code ", status);

  return part;
#else
  AT_ERROR("Not compiled with MTMETIS support");
#endif
}