"src/diffusers/commands/env.py" did not exist on "fd768456511a8c41c6b87032ddde8bcfd9845290"
fps_cpu.cpp 1.76 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include "fps_cpu.h"

3
4
#include <ATen/Parallel.h>

rusty1s's avatar
rusty1s committed
5
6
7
#include "utils.h"

// Returns, for every row of `x`, the sum over dim 1 of the squared difference
// between that row and row `idx`. For a 2-D `x` this is the squared Euclidean
// distance of each point to the point at `idx`.
inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
  // `x - x[idx]` allocates a fresh tensor, so squaring it in place does not
  // modify the caller's data.
  auto offset = x - x[idx];
  offset.pow_(2);
  return offset.sum(1);
}

11
// Farthest point sampling on the CPU over a batch of point sets.
//
// Parameters:
//   src          - points of all examples stacked along dim 0; trailing
//                  dimensions are flattened into a single feature dimension.
//   ptr          - 1-D int64 tensor holding batch_size + 1 row offsets into
//                  `src`; example b owns rows [ptr[b], ptr[b + 1]).
//   ratio        - multiplied with the per-example row counts; example b
//                  yields ceil(count_b * ratio) samples.
//   random_start - if true, the first sample of each example is drawn with
//                  rand(); otherwise the first row of the example is used.
//
// Returns a 1-D tensor (same options as `ptr`) with the selected row indices
// into `src`, the samples of the examples laid out one after another.
//
// Within an example, each further sample is the row with the largest running
// minimum squared distance (see get_dist) to the rows selected so far.
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
                      bool random_start) {
  CHECK_CPU(src);
  CHECK_CPU(ptr);
  CHECK_CPU(ratio);
  CHECK_INPUT(ptr.dim() == 1);
  // `ptr` stores batch_size + 1 offsets, so it must hold at least one entry.
  CHECK_INPUT(ptr.numel() > 0);

  src = src.view({src.size(0), -1}).contiguous();
  ptr = ptr.contiguous();
  const int64_t batch_size = ptr.numel() - 1;

  // Nothing to sample for an empty batch; `out_ptr[-1]` below would fail on
  // an empty tensor.
  if (batch_size == 0)
    return torch::empty({0}, ptr.options());

  // Number of rows per example, then the cumulative number of samples to draw
  // (out_ptr[b] is the end offset of example b inside `out`).
  auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
  auto out_ptr = deg.toType(torch::kFloat) * ratio;
  out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);

  auto out = torch::empty({out_ptr[-1].data_ptr<int64_t>()[0]}, ptr.options());

  auto ptr_data = ptr.data_ptr<int64_t>();
  auto out_ptr_data = out_ptr.data_ptr<int64_t>();
  auto out_data = out.data_ptr<int64_t>();

  // Draw the start indices serially, before entering the parallel region:
  // rand() is not guaranteed to be thread-safe, and drawing here keeps the
  // sequence independent of how the batch is split across threads.
  auto start = torch::zeros({batch_size}, ptr.options());
  auto start_data = start.data_ptr<int64_t>();
  if (random_start) {
    for (int64_t b = 0; b < batch_size; b++) {
      const int64_t num_src = ptr_data[b + 1] - ptr_data[b];
      if (num_src > 0) // Avoid a modulo by zero for empty examples.
        start_data[b] = rand() % num_src;
    }
  }

  int64_t grain_size = 1; // Always parallelize over batch dimension.
  at::parallel_for(0, batch_size, grain_size, [&](int64_t begin, int64_t end) {
    for (int64_t b = begin; b < end; b++) {
      const int64_t src_start = ptr_data[b];
      const int64_t src_end = ptr_data[b + 1];
      const int64_t out_start = b == 0 ? 0 : out_ptr_data[b - 1];
      const int64_t out_end = out_ptr_data[b];

      // No samples requested for this example (empty example or zero ratio):
      // writing even the start index would touch a slot owned by another
      // example or lie past the end of `out`.
      if (out_end <= out_start)
        continue;

      auto y = src.narrow(0, src_start, src_end - src_start);

      const int64_t start_idx = start_data[b];
      out_data[out_start] = src_start + start_idx;
      auto dist = get_dist(y, start_idx);

      for (int64_t i = 1; i < out_end - out_start; i++) {
        // Pick the row farthest from everything selected so far, then fold
        // its distances into the running minimum.
        const int64_t argmax = dist.argmax().data_ptr<int64_t>()[0];
        out_data[out_start + i] = src_start + argmax;
        dist = torch::min(dist, get_dist(y, argmax));
      }
    }
  });

  return out;
}