/**
 *  Copyright (c) 2023 by Contributors
 *
 * @file unique_and_compact.cc
 * @brief Unique and compact op.
 */

#include <graphbolt/cuda_ops.h>
#include <graphbolt/unique_and_compact.h>

#include <unordered_map>

#include "./concurrent_id_hash_map.h"
#include "./macro.h"
#include "./utils.h"

namespace graphbolt {
namespace sampling {
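// Deduplicates the concatenation of unique_dst_ids and src_ids (destination
// ids first) and returns (unique_ids, compacted_src_ids, compacted_dst_ids),
// where the compacted tensors hold each id's index into unique_ids.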
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
    const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
    const torch::Tensor unique_dst_ids) {
  if (utils::is_on_gpu(src_ids) && utils::is_on_gpu(dst_ids) &&
      utils::is_on_gpu(unique_dst_ids)) {
    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
        c10::DeviceType::CUDA, "unique_and_compact",
        { return ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids); });
  }
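  // CPU fallback: build an id -> index map over the concatenated ids and
  // translate src_ids and dst_ids into indices in the unique id tensor.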
  torch::Tensor compacted_src_ids;
  torch::Tensor compacted_dst_ids;
  torch::Tensor unique_ids;
  auto num_dst = unique_dst_ids.size(0);
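  // Put unique_dst_ids first so destination ids receive the lowest indices
  // in the unique id space.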
  torch::Tensor ids = torch::cat({unique_dst_ids, src_ids});
// TODO: Remove this fallback after the Windows concurrency bug is fixed.
#ifdef _MSC_VER
  AT_DISPATCH_INTEGRAL_TYPES(
      ids.scalar_type(), "unique_and_compact", ([&] {
        std::unordered_map<scalar_t, scalar_t> id_map;
        unique_ids = torch::empty_like(ids);
        auto unique_ids_data = unique_ids.data_ptr<scalar_t>();
        auto ids_data = ids.data_ptr<scalar_t>();
        auto num_ids = ids.size(0);
        scalar_t index = 0;
        for (int64_t i = 0; i < num_ids; i++) {
          auto id = ids_data[i];
          if (id_map.count(id) == 0) {
            unique_ids_data[index] = id;
            id_map[id] = index++;
          }
        }
        unique_ids = unique_ids.slice(0, 0, index);
        compacted_src_ids = torch::empty_like(src_ids);
        compacted_dst_ids = torch::empty_like(dst_ids);
        num_ids = compacted_src_ids.size(0);
        auto src_ids_data = src_ids.data_ptr<scalar_t>();
        auto dst_ids_data = dst_ids.data_ptr<scalar_t>();
        auto compacted_src_ids_data = compacted_src_ids.data_ptr<scalar_t>();
        auto compacted_dst_ids_data = compacted_dst_ids.data_ptr<scalar_t>();
        torch::parallel_for(0, num_ids, 256, [&](int64_t s, int64_t e) {
          for (int64_t i = s; i < e; i++) {
            auto it = id_map.find(src_ids_data[i]);
            if (it == id_map.end())
              throw std::out_of_range(
                  "Id not found: " + std::to_string(src_ids_data[i]));
            compacted_src_ids_data[i] = it->second;
          }
        });
        num_ids = compacted_dst_ids.size(0);
        torch::parallel_for(0, num_ids, 256, [&](int64_t s, int64_t e) {
          for (int64_t i = s; i < e; i++) {
            auto it = id_map.find(dst_ids_data[i]);
            if (it == id_map.end())
              throw std::out_of_range(
                  "Id not found: " + std::to_string(dst_ids_data[i]));
            compacted_dst_ids_data[i] = it->second;
          }
        });
      }));
#else
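  // On other platforms, build the mapping with ConcurrentIdHashMap: Init()
  // deduplicates `ids`, reserving the first num_dst slots for the destination
  // ids, and MapIds() converts ids to their indices in the unique set.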
  AT_DISPATCH_INTEGRAL_TYPES(ids.scalar_type(), "unique_and_compact", ([&] {
                               ConcurrentIdHashMap<scalar_t> id_map;
                               unique_ids = id_map.Init(ids, num_dst);
                               compacted_src_ids = id_map.MapIds(src_ids);
                               compacted_dst_ids = id_map.MapIds(dst_ids);
                             }));
#endif
  return std::tuple(unique_ids, compacted_src_ids, compacted_dst_ids);
}

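// Batched variant: applies UniqueAndCompact to each (src_ids, dst_ids,
// unique_dst_ids) triple, dispatching the whole batch to the CUDA
// implementation when every input tensor resides on the GPU.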
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatched(
    const std::vector<torch::Tensor>& src_ids,
    const std::vector<torch::Tensor>& dst_ids,
    const std::vector<torch::Tensor> unique_dst_ids) {
  TORCH_CHECK(
      src_ids.size() == dst_ids.size() &&
          dst_ids.size() == unique_dst_ids.size(),
      "The batch dimensions of the parameters need to be identical.");
  bool all_on_gpu = true;
  for (std::size_t i = 0; i < src_ids.size(); i++) {
    all_on_gpu = all_on_gpu && utils::is_on_gpu(src_ids[i]) &&
                 utils::is_on_gpu(dst_ids[i]) &&
                 utils::is_on_gpu(unique_dst_ids[i]);
    if (!all_on_gpu) break;
  }
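  // Dispatch the whole batch to the CUDA implementation at once only when
  // every tensor in the batch is on the GPU.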
  if (all_on_gpu) {
    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
        c10::DeviceType::CUDA, "unique_and_compact", {
          return ops::UniqueAndCompactBatched(src_ids, dst_ids, unique_dst_ids);
        });
  }
  std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> results;
  results.reserve(src_ids.size());
  for (std::size_t i = 0; i < src_ids.size(); i++) {
    results.emplace_back(
        UniqueAndCompact(src_ids[i], dst_ids[i], unique_dst_ids[i]));
  }
  return results;
}

}  // namespace sampling
}  // namespace graphbolt