metis_cpu.cpp 1.32 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
#include "metis_cpu.h"
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
#ifdef WITH_METIS
rusty1s's avatar
rusty1s committed
4
#include <metis.h>
rusty1s's avatar
rusty1s committed
5
#endif
rusty1s's avatar
rusty1s committed
6
7
8

#include "utils.h"

rusty1s's avatar
update  
rusty1s committed
9
10
11
/// Partitions a graph given in CSR form into `num_parts` parts with METIS.
///
/// @param rowptr          CSR row pointer; read as int64 and must be a CPU
///                        tensor. Its length is the number of vertices + 1.
/// @param col             CSR column indices; read as int64 and must be a
///                        CPU tensor.
/// @param optional_value  Optional 1-D edge weights with one entry per element
///                        of `col`; read as int64 and handed to METIS as
///                        `adjwgt`. When absent, no edge weights are passed.
/// @param num_parts       Number of parts requested from METIS.
/// @param recursive       If true, use METIS_PartGraphRecursive; otherwise
///                        METIS_PartGraphKway.
/// @return One entry per vertex holding its part id, created with the same
///         options as `rowptr`.
/// @throws c10::Error if an input check fails, if METIS reports a non-OK
///         status, or if the extension was built without METIS support.
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                            torch::optional<torch::Tensor> optional_value,
                            int64_t num_parts, bool recursive) {
#ifdef WITH_METIS
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (optional_value.has_value()) {
    CHECK_CPU(optional_value.value());
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().numel() == col.numel());
  }

  // METIS walks the raw buffers linearly, so strided views would be read
  // incorrectly. `contiguous()` is a no-op for tensors that already are dense.
  rowptr = rowptr.contiguous();
  col = col.contiguous();

  int64_t nvtxs = rowptr.numel() - 1;
  auto part = torch::empty(nvtxs, rowptr.options());

  // Nothing to partition: return the empty assignment without calling METIS.
  if (nvtxs == 0)
    return part;

  auto *xadj = rowptr.data_ptr<int64_t>();
  auto *adjncy = col.data_ptr<int64_t>();

  // Holds the (possibly newly materialized) contiguous edge weights so their
  // buffer stays alive while METIS reads it.
  torch::Tensor value;
  int64_t *adjwgt = nullptr;
  if (optional_value.has_value()) {
    value = optional_value.value().contiguous();
    adjwgt = value.data_ptr<int64_t>();
  }

  int64_t ncon = 1; // METIS `ncon`: number of vertex balancing constraints.
  int64_t objval = -1;
  auto part_data = part.data_ptr<int64_t>();

  int status;
  if (recursive) {
    status = METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, nullptr,
                                      nullptr, adjwgt, &num_parts, nullptr,
                                      nullptr, nullptr, &objval, part_data);
  } else {
    status = METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, nullptr,
                                 nullptr, adjwgt, &num_parts, nullptr, nullptr,
                                 nullptr, &objval, part_data);
  }

  // `part` was allocated uninitialized (torch::empty), so a failed METIS call
  // would otherwise hand garbage partition ids back to the caller.
  TORCH_CHECK(status == METIS_OK, "METIS partitioning failed with status ",
              status);

  return part;
#else
  AT_ERROR("Not compiled with METIS support");
#endif
}