/**
 *  Copyright (c) 2017 by Contributors
 * @file workspace_pool.cc
 * @brief Workspace pool utility.
 */
#include "workspace_pool.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
Minjie Wang's avatar
Minjie Wang committed
9

10
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
namespace runtime {

// Allocation granularity: every workspace request is rounded up to a
// multiple of this many bytes (4 KiB).
constexpr size_t kWorkspacePageSize = 4 << 10;

/**
 * @brief Per-device cache of temporary workspace buffers.
 *
 * Buffers handed back through Free() are not returned to the device; they are
 * kept in free_list_ (sorted by ascending size) and reused by later Alloc()
 * calls. Both free_list_ and allocated_ always begin with a sentinel entry
 * {nullptr, 0}: back() is therefore always valid, and index 0 never refers to
 * a real buffer.
 */
class WorkspacePool::Pool {
 public:
  Pool() {
    // Sentinel at the front of each list (see class comment).
    free_list_.push_back(Entry{nullptr, 0});
    allocated_.push_back(Entry{nullptr, 0});
  }

  /**
   * @brief Allocate a workspace buffer of at least @p nbytes bytes.
   *
   * The size is rounded up to a whole number of pages (at least one page).
   * The smallest cached buffer that is large enough is reused. If every cached
   * buffer is too small, the largest one is freed and replaced by a new buffer
   * of the requested size. If nothing is cached, a new buffer is allocated.
   *
   * @param ctx Device context forwarded to @p device.
   * @param device Device API used to allocate/free the underlying memory.
   * @param nbytes Requested size in bytes.
   * @return Pointer to the buffer; give it back with Free().
   */
  void* Alloc(DGLContext ctx, DeviceAPI* device, size_t nbytes) {
    // Reject sizes whose page rounding would wrap around to 0; without this
    // check such a request would silently receive a single-page buffer.
    constexpr size_t kMaxRequest =
        std::numeric_limits<size_t>::max() - (kWorkspacePageSize - 1);
    CHECK_LE(nbytes, kMaxRequest)
        << "workspace request of " << nbytes << " bytes is too large";
    nbytes = (nbytes + (kWorkspacePageSize - 1)) / kWorkspacePageSize *
             kWorkspacePageSize;
    if (nbytes == 0) nbytes = kWorkspacePageSize;

    DGLDataType type;
    type.code = kDGLUInt;
    type.bits = 8;
    type.lanes = 1;

    Entry e{};
    // Smallest cached buffer with size >= nbytes, skipping the sentinel.
    auto fit = std::lower_bound(
        free_list_.begin() + 1, free_list_.end(), nbytes,
        [](const Entry& entry, size_t n) { return entry.size < n; });
    if (fit != free_list_.end()) {
      e = *fit;
      free_list_.erase(fit);
    } else if (free_list_.size() > 1) {
      // Every cached buffer is too small: replace the largest one.
      e = free_list_.back();
      free_list_.pop_back();
      device->FreeDataSpace(ctx, e.data);
      e.data = device->AllocDataSpace(ctx, nbytes, kTempAllocaAlignment, type);
      e.size = nbytes;
    } else {
      // Nothing cached yet.
      e.data = device->AllocDataSpace(ctx, nbytes, kTempAllocaAlignment, type);
      e.size = nbytes;
    }
    allocated_.push_back(e);
    return e.data;
  }

  /**
   * @brief Return a buffer obtained from Alloc() to the pool.
   *
   * The buffer is not released to the device; it is cached in free_list_ for
   * reuse. CHECK-fails if @p data is not a currently allocated buffer.
   */
  void Free(void* data) {
    // Scan from the most recent allocation (the common case) down to index 1.
    // Index 0 is the sentinel and must never be matched or removed, even when
    // data == nullptr.
    size_t index = allocated_.size() - 1;
    while (index > 0 && allocated_[index].data != data) --index;
    CHECK_GT(index, 0U) << "trying to free things that has not been allocated";
    const Entry e = allocated_[index];
    allocated_.erase(allocated_.begin() + index);

    // Keep free_list_ sorted by ascending size; the sentinel (size 0) stays in
    // front because insertion starts after it.
    auto pos = std::upper_bound(
        free_list_.begin() + 1, free_list_.end(), e.size,
        [](size_t n, const Entry& entry) { return n < entry.size; });
    free_list_.insert(pos, e);
  }

  /**
   * @brief Release every cached buffer back to the device.
   *
   * Requires that no buffer is currently allocated (CHECK-fails otherwise).
   * The sentinel entry is kept so the pool remains in a valid, reusable state.
   */
  void Release(DGLContext ctx, DeviceAPI* device) {
    CHECK_EQ(allocated_.size(), 1);
    for (size_t i = 1; i < free_list_.size(); ++i) {
      device->FreeDataSpace(ctx, free_list_[i].data);
    }
    free_list_.resize(1);
  }

 private:
  /** @brief A single buffer: its pointer and its page-rounded size. */
  struct Entry {
    void* data;
    size_t size;
  };
  /** @brief Cached free buffers, sorted by ascending size; [0] is a sentinel. */
  std::vector<Entry> free_list_;
  /** @brief Buffers currently handed out, oldest first; [0] is a sentinel. */
  std::vector<Entry> allocated_;
};
/**
 * @brief Create an empty workspace pool.
 * @param device_type Device type, stored in device_type_.
 * @param device Device API used by the per-device pools to allocate and free
 *        memory. Ownership is shared; the pointer is moved into device_ to
 *        avoid an extra reference-count round-trip.
 */
WorkspacePool::WorkspacePool(
    DGLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
    : device_type_(device_type), device_(std::move(device)) {}

// NOTE: the per-device pools in array_ are intentionally NOT released here.
// Doing so caused a segmentation fault when running together with MXNet.
// This destructor is expected to run only at process termination, so skipping
// the cleanup should be harmless in practice. It does, however, leak every
// Pool and its cached buffers, which remains an open issue. The disabled
// cleanup was:
//
//   for (size_t i = 0; i < array_.size(); ++i) {
//     if (array_[i] != nullptr) {
//       DGLContext ctx;
//       ctx.device_type = device_type_;
//       ctx.device_id = static_cast<int>(i);
//       array_[i]->Release(ctx, device_.get());
//       delete array_[i];
//     }
//   }
WorkspacePool::~WorkspacePool() = default;

/**
 * @brief Allocate a workspace buffer of at least @p size bytes on the device
 *        identified by ctx.device_id.
 *
 * The Pool for a device id is created lazily on first use. CHECK-fails if
 * ctx.device_id is negative.
 *
 * @return Pointer to the buffer; give it back with FreeWorkspace().
 */
void* WorkspacePool::AllocWorkspace(DGLContext ctx, size_t size) {
  // A negative id would wrap to a huge size_t: the bounds test would pass and
  // resize(device_id + 1) would shrink array_ (e.g. to 0 for -1), leaking every
  // existing pool and then indexing out of bounds.
  CHECK_GE(ctx.device_id, 0) << "invalid device id " << ctx.device_id;
  const size_t dev = static_cast<size_t>(ctx.device_id);
  if (dev >= array_.size()) {
    array_.resize(dev + 1, nullptr);
  }
  if (array_[dev] == nullptr) {
    array_[dev] = new Pool();
  }
  return array_[dev]->Alloc(ctx, device_.get(), size);
}

/**
 * @brief Give @p ptr, obtained from AllocWorkspace() with the same device id,
 *        back to that device's pool.
 *
 * CHECK-fails if no pool exists for ctx.device_id.
 */
void WorkspacePool::FreeWorkspace(DGLContext ctx, void* ptr) {
  CHECK(static_cast<size_t>(ctx.device_id) < array_.size() &&
        array_[ctx.device_id] != nullptr);
  auto* pool = array_[ctx.device_id];
  pool->Free(ptr);
}

}  // namespace runtime
163
}  // namespace dgl