metis_cpu.cpp 3.05 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include "metis_cpu.h"

#include <cstdlib>
#include <memory>
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
#ifdef WITH_METIS
rusty1s's avatar
rusty1s committed
4
#include <metis.h>
rusty1s's avatar
rusty1s committed
5
#endif
rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
8
9
10
#ifdef WITH_MTMETIS
#include <mtmetis.h>
#endif

rusty1s's avatar
rusty1s committed
11
12
#include "utils.h"

rusty1s's avatar
update  
rusty1s committed
13
14
15
// Partitions a graph given in CSR form into `num_parts` parts using METIS and
// returns a tensor (same options as `rowptr`) with one part id per node.
//
// rowptr, col     - CSR structure of the graph, read as int64 buffers. The
//                   int64_t* arguments only compile against a METIS build
//                   whose idx_t is 64-bit.
// optional_value  - optional 1-D edge weights aligned with `col`, read as
//                   int64 (data_ptr<int64_t> throws for other dtypes).
// num_parts       - number of parts requested.
// recursive       - use METIS_PartGraphRecursive instead of
//                   METIS_PartGraphKway.
//
// Throws if METIS support was not compiled in or if METIS reports an error.
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                            torch::optional<torch::Tensor> optional_value,
                            int64_t num_parts, bool recursive) {
#ifdef WITH_METIS
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (optional_value.has_value()) {
    CHECK_CPU(optional_value.value());
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().numel() == col.numel());
  }

  // METIS reads the raw buffers linearly, so strided views must be compacted
  // first. `.contiguous()` is a no-op for tensors that already are; the
  // reassigned tensors keep any copies alive until METIS returns.
  rowptr = rowptr.contiguous();
  col = col.contiguous();

  int64_t nvtxs = rowptr.numel() - 1;
  int64_t ncon = 1;
  auto *xadj = rowptr.data_ptr<int64_t>();
  auto *adjncy = col.data_ptr<int64_t>();

  torch::Tensor value;
  int64_t *adjwgt = nullptr;
  if (optional_value.has_value()) {
    value = optional_value.value().contiguous();
    adjwgt = value.data_ptr<int64_t>();
  }

  int64_t objval = -1;
  auto part = torch::empty(nvtxs, rowptr.options());
  auto *part_data = part.data_ptr<int64_t>();

  const int status =
      recursive
          ? METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, nullptr,
                                     nullptr, adjwgt, &num_parts, nullptr,
                                     nullptr, nullptr, &objval, part_data)
          : METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, nullptr, nullptr,
                                adjwgt, &num_parts, nullptr, nullptr, nullptr,
                                &objval, part_data);

  // `part` was allocated uninitialized; never hand it back if METIS failed.
  if (status != METIS_OK)
    AT_ERROR("METIS partitioning failed with error code ", status);

  return part;
#else
  AT_ERROR("Not compiled with METIS support");
#endif
}
rusty1s's avatar
rusty1s committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

// needs mt-metis installed via:
// ./configure --shared --edges64bit --vertices64bit --weights64bit
//             --partitions64bit
//
// Multi-threaded counterpart of partition_cpu(): partitions a graph given in
// CSR form into `num_parts` parts with mt-metis, using `num_workers` threads,
// and returns a tensor (same options as `rowptr`) with one part id per node.
//
// rowptr, col     - CSR structure of the graph, read as int64 buffers and
//                   reinterpreted as mt-metis' integer types.
// optional_value  - optional 1-D edge weights aligned with `col`, read as
//                   int64 (data_ptr<int64_t> throws for other dtypes).
// num_parts       - number of parts requested.
// recursive       - use MTMETIS_PartGraphRecursive instead of
//                   MTMETIS_PartGraphKway.
// num_workers     - value stored in opts[MTMETIS_OPTION_NTHREADS].
//
// Throws if mt-metis support was not compiled in or if mt-metis reports an
// error.
torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                               torch::optional<torch::Tensor> optional_value,
                               int64_t num_parts, bool recursive,
                               int64_t num_workers) {
#ifdef WITH_MTMETIS
  // The int64 tensor buffers below are reinterpreted as mt-metis' integer
  // types, which is only valid for the 64-bit build described above. Fail at
  // compile time instead of silently misreading the graph.
  static_assert(sizeof(mtmetis_adj_type) == sizeof(int64_t),
                "mt-metis must be configured with --edges64bit");
  static_assert(sizeof(mtmetis_vtx_type) == sizeof(int64_t),
                "mt-metis must be configured with --vertices64bit");
  static_assert(sizeof(mtmetis_pid_type) == sizeof(int64_t),
                "mt-metis must be configured with --partitions64bit");

  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (optional_value.has_value()) {
    CHECK_CPU(optional_value.value());
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().numel() == col.numel());
  }

  // mt-metis reads the raw buffers linearly, so strided views must be
  // compacted first; the reassigned tensors keep any copies alive.
  rowptr = rowptr.contiguous();
  col = col.contiguous();

  mtmetis_vtx_type nvtxs = rowptr.numel() - 1;
  mtmetis_vtx_type ncon = 1;
  auto *xadj = reinterpret_cast<mtmetis_adj_type *>(rowptr.data_ptr<int64_t>());
  auto *adjncy = reinterpret_cast<mtmetis_vtx_type *>(col.data_ptr<int64_t>());

  torch::Tensor value;
  mtmetis_wgt_type *adjwgt = nullptr;
  if (optional_value.has_value()) {
    value = optional_value.value().contiguous();
    adjwgt = value.data_ptr<int64_t>();
  }

  mtmetis_pid_type nparts = num_parts;
  mtmetis_wgt_type objval = -1;
  auto part = torch::empty(nvtxs, rowptr.options());
  auto *part_data =
      reinterpret_cast<mtmetis_pid_type *>(part.data_ptr<int64_t>());

  // mtmetis_init_options() returns a heap array owned by the caller; the
  // previous code never released it. Own it so it is freed on every path.
  // NOTE(review): assumes the array is malloc()-allocated (as in mt-metis'
  // own sources) - confirm against the installed mt-metis version.
  struct FreeOptions {
    void operator()(double *p) const { std::free(p); }
  };
  std::unique_ptr<double, FreeOptions> opts(mtmetis_init_options());
  if (!opts)
    AT_ERROR("mt-metis failed to allocate its options array");
  opts.get()[MTMETIS_OPTION_NTHREADS] = num_workers;

  const int status =
      recursive
          ? MTMETIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, nullptr,
                                       nullptr, adjwgt, &nparts, nullptr,
                                       nullptr, opts.get(), &objval, part_data)
          : MTMETIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, nullptr,
                                  nullptr, adjwgt, &nparts, nullptr, nullptr,
                                  opts.get(), &objval, part_data);

  // `part` was allocated uninitialized; never hand it back if mt-metis failed.
  if (status != MTMETIS_SUCCESS)
    AT_ERROR("mt-metis partitioning failed with error code ", status);

  return part;
#else
  AT_ERROR("Not compiled with MTMETIS support");
#endif
}