/*!
 *  Copyright (c) 2016-2022 by Contributors
 * \file cpu_device_api.cc
 * \brief CPU implementation of the DGL DeviceAPI.
 */
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
9
#include <dgl/runtime/tensordispatch.h>
Minjie Wang's avatar
Minjie Wang committed
10
11
12
13
#include <cstdlib>
#include <cstring>
#include "workspace_pool.h"

14
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
15
16
17
namespace runtime {
class CPUDeviceAPI final : public DeviceAPI {
 public:
18
19
  void SetDevice(DGLContext ctx) final {}
  void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) final {
Minjie Wang's avatar
Minjie Wang committed
20
21
22
23
    if (kind == kExist) {
      *rv = 1;
    }
  }
24
  void* AllocDataSpace(DGLContext ctx,
Minjie Wang's avatar
Minjie Wang committed
25
26
                       size_t nbytes,
                       size_t alignment,
27
                       DGLType type_hint) final {
28
29
30
31
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      return td->CPUAllocWorkspace(nbytes);

Minjie Wang's avatar
Minjie Wang committed
32
    void* ptr;
33
#if _MSC_VER || defined(__MINGW32__)
Minjie Wang's avatar
Minjie Wang committed
34
35
36
37
38
39
40
41
42
43
44
45
    ptr = _aligned_malloc(nbytes, alignment);
    if (ptr == nullptr) throw std::bad_alloc();
#elif defined(_LIBCPP_SGX_CONFIG)
    ptr = memalign(alignment, nbytes);
    if (ptr == nullptr) throw std::bad_alloc();
#else
    int ret = posix_memalign(&ptr, alignment, nbytes);
    if (ret != 0) throw std::bad_alloc();
#endif
    return ptr;
  }

46
  void FreeDataSpace(DGLContext ctx, void* ptr) final {
47
48
49
50
    TensorDispatcher* td = TensorDispatcher::Global();
    if (td->IsAvailable())
      return td->CPUFreeWorkspace(ptr);

51
#if _MSC_VER || defined(__MINGW32__)
Minjie Wang's avatar
Minjie Wang committed
52
53
54
55
56
57
58
59
60
61
62
    _aligned_free(ptr);
#else
    free(ptr);
#endif
  }

  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
63
64
65
66
                      DGLContext ctx_from,
                      DGLContext ctx_to,
                      DGLType type_hint,
                      DGLStreamHandle stream) final {
Minjie Wang's avatar
Minjie Wang committed
67
68
69
70
71
    memcpy(static_cast<char*>(to) + to_offset,
           static_cast<const char*>(from) + from_offset,
           size);
  }

72
73
  DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }

74
  void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
Minjie Wang's avatar
Minjie Wang committed
75
76
  }

77
78
  void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final;
  void FreeWorkspace(DGLContext ctx, void* data) final;
Minjie Wang's avatar
Minjie Wang committed
79
80
81
82
83
84
85
86
87
88
89
90
91

  static const std::shared_ptr<CPUDeviceAPI>& Global() {
    static std::shared_ptr<CPUDeviceAPI> inst =
        std::make_shared<CPUDeviceAPI>();
    return inst;
  }
};

/*!
 * \brief Workspace pool for kDLCPU whose backing allocations go through the
 *        CPUDeviceAPI singleton.
 */
struct CPUWorkspacePool : WorkspacePool {
  CPUWorkspacePool()
      : WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
};

92
void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx,
Minjie Wang's avatar
Minjie Wang committed
93
                                   size_t size,
94
                                   DGLType type_hint) {
95
96
97
98
99
  TensorDispatcher* td = TensorDispatcher::Global();
  if (td->IsAvailable())
    return td->CPUAllocWorkspace(size);

  return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(ctx, size);
Minjie Wang's avatar
Minjie Wang committed
100
101
}

102
/*!
 * \brief Return a buffer obtained from AllocWorkspace.
 *
 * Routed the same way as AllocWorkspace: to the TensorDispatcher when it is
 * available, otherwise to the thread-local CPUWorkspacePool.
 */
void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
  TensorDispatcher* dispatcher = TensorDispatcher::Global();
  if (dispatcher->IsAvailable()) {
    dispatcher->CPUFreeWorkspace(data);
  } else {
    CPUWorkspacePool* pool = dmlc::ThreadLocalStore<CPUWorkspacePool>::Get();
    pool->FreeWorkspace(ctx, data);
  }
}

110
111
// Publish the CPU DeviceAPI singleton under "device_api.cpu" as an opaque
// handle. The pointer is upcast to DeviceAPI* before being erased to void*.
// NOTE(review): presumably the consumer casts it back to DeviceAPI*.
DGL_REGISTER_GLOBAL("device_api.cpu")
.set_body([](DGLArgs args, DGLRetValue* rv) {
    DeviceAPI* api = CPUDeviceAPI::Global().get();
    void* handle = api;
    *rv = handle;
  });
}  // namespace runtime
116
}  // namespace dgl