cpu_device_api.cc 3.09 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
/*!
 *  Copyright (c) 2016-2022 by Contributors
 * @file cpu_device_api.cc
 * @brief Device API implementation for host (CPU) memory.
 */
#include <dgl/runtime/device_api.h>
6
#include <dgl/runtime/registry.h>
7
#include <dgl/runtime/tensordispatch.h>
8
9
10
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>

Minjie Wang's avatar
Minjie Wang committed
11
12
#include <cstdlib>
#include <cstring>
13

Minjie Wang's avatar
Minjie Wang committed
14
15
#include "workspace_pool.h"

16
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
17
18
19
namespace runtime {
class CPUDeviceAPI final : public DeviceAPI {
 public:
20
21
  void SetDevice(DGLContext ctx) final {}
  void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) final {
Minjie Wang's avatar
Minjie Wang committed
22
23
24
25
    if (kind == kExist) {
      *rv = 1;
    }
  }
26
27
28
  void* AllocDataSpace(
      DGLContext ctx, size_t nbytes, size_t alignment,
      DGLDataType type_hint) final {
29
    TensorDispatcher* td = TensorDispatcher::Global();
30
    if (td->IsAvailable()) return td->CPUAllocWorkspace(nbytes);
31

Minjie Wang's avatar
Minjie Wang committed
32
    void* ptr;
33
#if _MSC_VER || defined(__MINGW32__)
Minjie Wang's avatar
Minjie Wang committed
34
35
36
37
38
39
40
41
42
43
44
45
    ptr = _aligned_malloc(nbytes, alignment);
    if (ptr == nullptr) throw std::bad_alloc();
#elif defined(_LIBCPP_SGX_CONFIG)
    ptr = memalign(alignment, nbytes);
    if (ptr == nullptr) throw std::bad_alloc();
#else
    int ret = posix_memalign(&ptr, alignment, nbytes);
    if (ret != 0) throw std::bad_alloc();
#endif
    return ptr;
  }

46
  void FreeDataSpace(DGLContext ctx, void* ptr) final {
47
    TensorDispatcher* td = TensorDispatcher::Global();
48
    if (td->IsAvailable()) return td->CPUFreeWorkspace(ptr);
49

50
#if _MSC_VER || defined(__MINGW32__)
Minjie Wang's avatar
Minjie Wang committed
51
52
53
54
55
56
    _aligned_free(ptr);
#else
    free(ptr);
#endif
  }

57
58
59
60
61
62
63
  void CopyDataFromTo(
      const void* from, size_t from_offset, void* to, size_t to_offset,
      size_t size, DGLContext ctx_from, DGLContext ctx_to,
      DGLDataType type_hint) final {
    memcpy(
        static_cast<char*>(to) + to_offset,
        static_cast<const char*>(from) + from_offset, size);
Minjie Wang's avatar
Minjie Wang committed
64
65
  }

66
67
  DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }

68
  void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {}
Minjie Wang's avatar
Minjie Wang committed
69

70
71
  void* AllocWorkspace(
      DGLContext ctx, size_t size, DGLDataType type_hint) final;
72
  void FreeWorkspace(DGLContext ctx, void* data) final;
Minjie Wang's avatar
Minjie Wang committed
73
74
75
76
77
78
79
80
81

  static const std::shared_ptr<CPUDeviceAPI>& Global() {
    static std::shared_ptr<CPUDeviceAPI> inst =
        std::make_shared<CPUDeviceAPI>();
    return inst;
  }
};

struct CPUWorkspacePool : public WorkspacePool {
82
  CPUWorkspacePool() : WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
Minjie Wang's avatar
Minjie Wang committed
83
84
};

85
86
/*!
 * @brief Allocate @p size bytes of temporary workspace.
 *
 * Forwards to the TensorDispatcher CPU allocator when it reports availability;
 * otherwise draws from this thread's CPUWorkspacePool.
 */
void* CPUDeviceAPI::AllocWorkspace(
    DGLContext ctx, size_t size, DGLDataType type_hint) {
  TensorDispatcher* dispatcher = TensorDispatcher::Global();
  if (!dispatcher->IsAvailable()) {
    CPUWorkspacePool* pool = dmlc::ThreadLocalStore<CPUWorkspacePool>::Get();
    return pool->AllocWorkspace(ctx, size);
  }
  return dispatcher->CPUAllocWorkspace(size);
}

94
/*!
 * @brief Return workspace memory obtained from AllocWorkspace.
 *
 * Mirrors AllocWorkspace: goes to the TensorDispatcher CPU allocator when it
 * reports availability, otherwise back to this thread's CPUWorkspacePool.
 */
void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
  TensorDispatcher* dispatcher = TensorDispatcher::Global();
  if (!dispatcher->IsAvailable()) {
    CPUWorkspacePool* pool = dmlc::ThreadLocalStore<CPUWorkspacePool>::Get();
    pool->FreeWorkspace(ctx, data);
    return;
  }
  dispatcher->CPUFreeWorkspace(data);
}

101
// Exposes the shared CPUDeviceAPI instance as an opaque handle under the
// global name "device_api.cpu".
DGL_REGISTER_GLOBAL("device_api.cpu")
    .set_body([](DGLArgs call_args, DGLRetValue* result) {
      // Upcast to the DeviceAPI base before erasing to void*; presumably
      // consumers cast the handle back to DeviceAPI*.
      DeviceAPI* base_api = CPUDeviceAPI::Global().get();
      *result = static_cast<void*>(base_api);
    });
Minjie Wang's avatar
Minjie Wang committed
106
}  // namespace runtime
107
}  // namespace dgl