/**
 *  Copyright (c) 2016-2022 by Contributors
 * @file cpu_device_api.cc
 * @brief CPU implementation of the DGL DeviceAPI interface.
 */
#include <dgl/runtime/device_api.h>
6
#include <dgl/runtime/registry.h>
7
#include <dgl/runtime/tensordispatch.h>
8
9
10
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>

Minjie Wang's avatar
Minjie Wang committed
11
12
#include <cstdlib>
#include <cstring>
13

Minjie Wang's avatar
Minjie Wang committed
14
15
#include "workspace_pool.h"

16
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
17
18
19
namespace runtime {
class CPUDeviceAPI final : public DeviceAPI {
 public:
20
21
  void SetDevice(DGLContext ctx) final {}
  void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) final {
Minjie Wang's avatar
Minjie Wang committed
22
23
24
25
    if (kind == kExist) {
      *rv = 1;
    }
  }
26
27
28
  void* AllocDataSpace(
      DGLContext ctx, size_t nbytes, size_t alignment,
      DGLDataType type_hint) final {
29
30
31
    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();
    if (tensor_dispatcher->IsAvailable())
      return tensor_dispatcher->CPUAllocWorkspace(nbytes);
32

Minjie Wang's avatar
Minjie Wang committed
33
    void* ptr;
34
#if _MSC_VER || defined(__MINGW32__)
Minjie Wang's avatar
Minjie Wang committed
35
36
37
38
39
40
41
42
43
44
45
46
    ptr = _aligned_malloc(nbytes, alignment);
    if (ptr == nullptr) throw std::bad_alloc();
#elif defined(_LIBCPP_SGX_CONFIG)
    ptr = memalign(alignment, nbytes);
    if (ptr == nullptr) throw std::bad_alloc();
#else
    int ret = posix_memalign(&ptr, alignment, nbytes);
    if (ret != 0) throw std::bad_alloc();
#endif
    return ptr;
  }

47
  void FreeDataSpace(DGLContext ctx, void* ptr) final {
48
49
50
    TensorDispatcher* tensor_dispatcher = TensorDispatcher::Global();
    if (tensor_dispatcher->IsAvailable())
      return tensor_dispatcher->CPUFreeWorkspace(ptr);
51

52
#if _MSC_VER || defined(__MINGW32__)
Minjie Wang's avatar
Minjie Wang committed
53
54
55
56
57
58
    _aligned_free(ptr);
#else
    free(ptr);
#endif
  }

59
60
61
62
63
64
65
  void CopyDataFromTo(
      const void* from, size_t from_offset, void* to, size_t to_offset,
      size_t size, DGLContext ctx_from, DGLContext ctx_to,
      DGLDataType type_hint) final {
    memcpy(
        static_cast<char*>(to) + to_offset,
        static_cast<const char*>(from) + from_offset, size);
Minjie Wang's avatar
Minjie Wang committed
66
67
  }

68
69
70
71
72
73
74
  void RecordedCopyDataFromTo(
      void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
      DGLContext ctx_from, DGLContext ctx_to, DGLDataType type_hint,
      void* pytorch_ctx) final {
    BUG_IF_FAIL(false) << "This piece of code should not be reached.";
  }

75
76
  DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }

77
  void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {}
Minjie Wang's avatar
Minjie Wang committed
78

79
80
  void* AllocWorkspace(
      DGLContext ctx, size_t size, DGLDataType type_hint) final;
81
  void FreeWorkspace(DGLContext ctx, void* data) final;
Minjie Wang's avatar
Minjie Wang committed
82
83
84
85
86
87
88
89
90

  static const std::shared_ptr<CPUDeviceAPI>& Global() {
    static std::shared_ptr<CPUDeviceAPI> inst =
        std::make_shared<CPUDeviceAPI>();
    return inst;
  }
};

struct CPUWorkspacePool : public WorkspacePool {
91
  CPUWorkspacePool() : WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
Minjie Wang's avatar
Minjie Wang committed
92
93
};

94
95
/**
 * @brief Hand out a temporary buffer of `size` bytes.
 *
 * Without an available TensorDispatcher the buffer comes from the
 * CPUWorkspacePool held in dmlc::ThreadLocalStore; otherwise it comes from
 * the dispatcher's CPUAllocWorkspace. `type_hint` is unused.
 */
void* CPUDeviceAPI::AllocWorkspace(
    DGLContext ctx, size_t size, DGLDataType type_hint) {
  TensorDispatcher* const dispatcher = TensorDispatcher::Global();
  if (!dispatcher->IsAvailable()) {
    auto* pool = dmlc::ThreadLocalStore<CPUWorkspacePool>::Get();
    return pool->AllocWorkspace(ctx, size);
  }
  return dispatcher->CPUAllocWorkspace(size);
}

105
/**
 * @brief Return a buffer obtained from AllocWorkspace.
 *
 * Mirrors AllocWorkspace: goes to the ThreadLocalStore-held CPUWorkspacePool
 * when the TensorDispatcher is unavailable, and to CPUFreeWorkspace otherwise.
 */
void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
  TensorDispatcher* const dispatcher = TensorDispatcher::Global();
  if (!dispatcher->IsAvailable()) {
    auto* pool = dmlc::ThreadLocalStore<CPUWorkspacePool>::Get();
    pool->FreeWorkspace(ctx, data);
    return;
  }
  dispatcher->CPUFreeWorkspace(data);
}

114
// Registers "device_api.cpu": the callback returns the CPUDeviceAPI singleton
// as an opaque void*. The pointer is upcast to DeviceAPI* first, so callers
// must cast it back to DeviceAPI*.
DGL_REGISTER_GLOBAL("device_api.cpu")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      *rv = static_cast<void*>(
          static_cast<DeviceAPI*>(CPUDeviceAPI::Global().get()));
    });
Minjie Wang's avatar
Minjie Wang committed
119
}  // namespace runtime
120
}  // namespace dgl