/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/gpu_cache.cu
 * @brief GPUCache implementation on CUDA.
 */
#include <numeric>

#include "./common.h"
#include "./gpu_cache.h"

namespace graphbolt {
namespace cuda {

/**
 * @brief Builds a GPU feature cache.
 *
 * @param shape Shape of the cached tensor: shape[0] is the number of items
 *        the cache should hold; the remaining dims describe one feature row.
 * @param dtype Element type of the cached features.
 */
GpuCache::GpuCache(const std::vector<int64_t> &shape, torch::ScalarType dtype) {
  TORCH_CHECK(shape.size() >= 2, "Shape must at least have 2 dimensions.");
  // Number of scalar elements in one feature row (product of dims past the
  // first).
  const int64_t feats_per_item = std::reduce(
      shape.begin() + 1, shape.end(), int64_t{1}, std::multiplies<>{});
  // Element size of `dtype`, obtained via a throwaway 1-element tensor.
  const int bytes_per_elem =
      torch::empty(1, torch::TensorOptions().dtype(dtype)).element_size();
  num_bytes_ = feats_per_item * bytes_per_elem;
  // The underlying cache stores each row as float words; round the byte
  // width up to a whole number of floats.
  num_float_feats_ = (num_bytes_ + sizeof(float) - 1) / sizeof(float);
  cache_ = std::make_unique<gpu_cache_t>(
      (shape[0] + bucket_size - 1) / bucket_size, num_float_feats_);
  shape_ = shape;
  shape_[0] = -1;  // first dim is determined by the number of queried keys
  dtype_ = dtype;
  device_id_ = cuda::GetCurrentStream().device_index();
}

/**
 * @brief Looks up `keys` in the cache.
 *
 * @param keys 1D CUDA tensor of keys on this cache's device; cast to int64
 *        before the lookup.
 *
 * @return A tuple (values, missing_index, missing_keys):
 *         - values: tensor shaped like `shape_` with first dim keys.size(0);
 *           rows for keys not found in the cache are left uninitialized
 *           (the buffer comes from torch::empty and only hits are filled).
 *         - missing_index: positions within `keys` that were cache misses.
 *         - missing_keys: the key values that were cache misses.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
    torch::Tensor keys) {
  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
  TORCH_CHECK(
      keys.device().index() == device_id_,
      "Keys should be on the correct CUDA device.");
  TORCH_CHECK(keys.sizes().size() == 1, "Keys should be a 1D tensor.");
  keys = keys.to(torch::kLong);
  // The cache's internal row layout is num_float_feats_ float words per key.
  auto values = torch::empty(
      {keys.size(0), num_float_feats_}, keys.options().dtype(torch::kFloat));
  // Worst case: every key is a miss, so size both outputs to keys.size(0)
  // and trim below once the actual miss count is known.
  auto missing_index =
      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
  auto missing_keys =
      torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
  // Host-readable scalar the lookup writes the miss count into; presumably
  // written asynchronously on `stream` — hence the synchronize below.
  cuda::CopyScalar<size_t> missing_len;
  auto stream = cuda::GetCurrentStream();
  cache_->Query(
      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
      values.data_ptr<float>(),
      reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
      reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
      stream);
  // Reinterpret the float rows as bytes, drop the per-row float padding
  // beyond num_bytes_, then view as the user's dtype and feature shape.
  values = values.view(torch::kByte)
               .slice(1, 0, num_bytes_)
               .view(dtype_)
               .view(shape_);
  // To safely read missing_len, we synchronize
  stream.synchronize();
  missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
  missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
  return std::make_tuple(values, missing_index, missing_keys);
}

/**
 * @brief Inserts or updates the given key/value pairs in the cache.
 *
 * @param keys 1D CUDA tensor of keys on this cache's device; cast to int64.
 * @param values Tensor of feature rows on the same device. values.size(0)
 *        must equal keys.size(0), the trailing dimensions must match the
 *        cached feature shape, and the dtype must be `dtype_`.
 */
void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
  TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
  TORCH_CHECK(
      keys.device().index() == device_id_,
      "Keys should be on the correct CUDA device.");
  // Fixed: this message previously said "Keys" for the values tensor.
  TORCH_CHECK(values.device().is_cuda(), "Values should be on a CUDA device.");
  TORCH_CHECK(
      values.device().index() == device_id_,
      "Values should be on the correct CUDA device.");
  TORCH_CHECK(
      keys.size(0) == values.size(0),
      "The first dimensions of keys and values must match.");
  // Check the rank first so the std::equal below never walks past the end of
  // values.sizes() (e.g. if a 1D values tensor is passed).
  TORCH_CHECK(
      values.sizes().size() == shape_.size(),
      "Values should have the correct number of dimensions.");
  TORCH_CHECK(
      std::equal(shape_.begin() + 1, shape_.end(), values.sizes().begin() + 1),
      "Values should have the correct dimensions.");
  TORCH_CHECK(
      values.scalar_type() == dtype_, "Values should have the correct dtype.");
  keys = keys.to(torch::kLong);
  // view(torch::kByte) below requires a contiguous tensor; normalize up
  // front instead of throwing on non-contiguous inputs.
  values = values.contiguous();
  torch::Tensor float_values;
  if (num_bytes_ % sizeof(float) != 0) {
    // Row width is not float-aligned: copy into a float buffer whose rows
    // are padded up to num_float_feats_ words (matching the cache layout).
    float_values = torch::empty(
        {values.size(0), num_float_feats_},
        values.options().dtype(torch::kFloat));
    float_values.view(torch::kByte)
        .slice(1, 0, num_bytes_)
        .copy_(values.view(torch::kByte).view({values.size(0), -1}));
  } else {
    // Row width is float-aligned: reinterpret in place, no extra copy.
    float_values = values.view(torch::kByte)
                       .view({values.size(0), -1})
                       .view(torch::kFloat)
                       .contiguous();
  }
  cache_->Replace(
      reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
      float_values.data_ptr<float>(), cuda::GetCurrentStream());
}

/**
 * @brief Factory producing a reference-counted GpuCache.
 *
 * @param shape Shape of the cached tensor; shape[0] is the item capacity.
 * @param dtype Element type of the cached features.
 * @return An intrusive pointer owning the newly constructed cache.
 */
c10::intrusive_ptr<GpuCache> GpuCache::Create(
    const std::vector<int64_t> &shape, torch::ScalarType dtype) {
  auto cache = c10::make_intrusive<GpuCache>(shape, dtype);
  return cache;
}

}  // namespace cuda
}  // namespace graphbolt