shared_mem_manager.cc 4.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/*!
 *  Copyright (c) 2018 by Contributors
 * \file graph/shared_mem_manager.cc
 * \brief DGL sampler implementation
 */
#include "shared_mem_manager.h"

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/random.h>
#include <dgl/runtime/container.h>
#include <dgl/sampler.h>
#include <dmlc/io.h>
#include <dmlc/memory_io.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <numeric>
#include <vector>

#include "../c_api_common.h"
#include "heterograph.h"

using namespace dgl::runtime;
using namespace dgl::aten;

namespace dgl {

template <>
NDArray SharedMemManager::CopyToSharedMem<NDArray>(const NDArray &data,
                                                   std::string name) {
  // Header record: rank, dtype, then the shape extents. The matching reader
  // (CreateFromSharedMem<NDArray>) consumes these in the same order.
  const int ndim = data->ndim;
  strm_->Write(ndim);
  strm_->Write(data->dtype);
  strm_->WriteArray(data->shape, ndim);

  // A null array is recorded as such and is not backed by a shared segment;
  // the caller gets the input array back unchanged.
  const bool is_null = IsNullArray(data);
  strm_->Write(is_null);
  if (is_null) return data;

  // Otherwise allocate a named shared-memory array (is_create = true) on the
  // CPU context and copy the contents of `data` into it.
  const DLContext cpu_ctx = {kDLCPU, 0};
  const std::vector<int64_t> extents(data->shape, data->shape + ndim);
  NDArray shared = NDArray::EmptyShared(graph_name_ + name, extents,
                                        data->dtype, cpu_ctx, true);
  shared.CopyFrom(data);
  return shared;
}

template <>
CSRMatrix SharedMemManager::CopyToSharedMem<CSRMatrix>(const CSRMatrix &csr,
                                                       std::string name) {
  // The three component arrays are written first, in this exact order
  // (indptr, indices, data); CreateFromSharedMem<CSRMatrix> reads them back
  // in the same sequence, so the statements below must not be reordered.
  const NDArray shared_indptr = CopyToSharedMem(csr.indptr, name + "_indptr");
  const NDArray shared_indices =
    CopyToSharedMem(csr.indices, name + "_indices");
  const NDArray shared_data = CopyToSharedMem(csr.data, name + "_data");

  // Scalar metadata follows the arrays in the stream.
  strm_->Write(csr.num_rows);
  strm_->Write(csr.num_cols);
  strm_->Write(csr.sorted);

  return CSRMatrix(csr.num_rows, csr.num_cols, shared_indptr, shared_indices,
                   shared_data, csr.sorted);
}

template <>
COOMatrix SharedMemManager::CopyToSharedMem<COOMatrix>(const COOMatrix &coo,
                                                       std::string name) {
  // Component arrays go first, in the order row, col, data; the reader
  // CreateFromSharedMem<COOMatrix> expects exactly this sequence.
  const NDArray shared_row = CopyToSharedMem(coo.row, name + "_row");
  const NDArray shared_col = CopyToSharedMem(coo.col, name + "_col");
  const NDArray shared_data = CopyToSharedMem(coo.data, name + "_data");

  // Then the scalar metadata, again in reader order.
  strm_->Write(coo.num_rows);
  strm_->Write(coo.num_cols);
  strm_->Write(coo.row_sorted);
  strm_->Write(coo.col_sorted);

  return COOMatrix(coo.num_rows, coo.num_cols, shared_row, shared_col,
                   shared_data, coo.row_sorted, coo.col_sorted);
}

/*!
 * \brief Rebuild an NDArray from the metadata stream and attach it to the
 *        named shared-memory segment written by CopyToSharedMem<NDArray>.
 *
 * Reads, in order: rank (int), dtype, `rank` shape extents, and a null flag.
 * A null array is recreated as a plain (non-shared) empty array of the
 * recorded shape; otherwise the existing shared segment
 * `graph_name_ + name` is opened (is_create = false).
 *
 * \param nd   Output array; overwritten on success.
 * \param name Suffix appended to graph_name_ to form the segment name.
 * \return Always true; any malformed or truncated stream aborts via CHECK.
 */
template <>
bool SharedMemManager::CreateFromSharedMem<NDArray>(NDArray *nd,
                                                    std::string name) {
  int ndim = 0;
  DLContext ctx = {kDLCPU, 0};
  DLDataType dtype{};

  CHECK(this->Read(&ndim)) << "Invalid DLTensor file format";
  CHECK(this->Read(&dtype)) << "Invalid DLTensor file format";
  // A negative rank can only come from a corrupted stream; reject it before
  // it is used as a vector size.
  CHECK_GE(ndim, 0) << "Invalid DLTensor file format";

  std::vector<int64_t> shape(ndim);
  if (ndim != 0) {
    CHECK(this->ReadArray(shape.data(), ndim))
      << "Invalid DLTensor file format";
  }
  // The null flag must be checked like every other field: an unchecked
  // failed read would leave it indeterminate and pick a branch at random.
  bool is_null = false;
  CHECK(this->Read(&is_null)) << "Invalid DLTensor file format";
  if (is_null) {
    *nd = NDArray::Empty(shape, dtype, ctx);
  } else {
    *nd =
      NDArray::EmptyShared(graph_name_ + name, shape, dtype, ctx, false);
  }
  return true;
}

/*!
 * \brief Rebuild a COOMatrix written by CopyToSharedMem<COOMatrix>.
 *
 * Reads the row, col and data arrays (in that order), then num_rows,
 * num_cols, row_sorted and col_sorted from the metadata stream.
 *
 * \param coo  Output matrix whose fields are overwritten.
 * \param name Prefix used to derive the per-array segment names.
 * \return Always true; a malformed or truncated stream aborts via CHECK
 *         instead of silently leaving fields uninitialized.
 */
template <>
bool SharedMemManager::CreateFromSharedMem<COOMatrix>(COOMatrix *coo,
                                                      std::string name) {
  CHECK(CreateFromSharedMem(&coo->row, name + "_row"))
    << "Invalid COOMatrix file format";
  CHECK(CreateFromSharedMem(&coo->col, name + "_col"))
    << "Invalid COOMatrix file format";
  CHECK(CreateFromSharedMem(&coo->data, name + "_data"))
    << "Invalid COOMatrix file format";
  CHECK(strm_->Read(&coo->num_rows)) << "Invalid COOMatrix file format";
  CHECK(strm_->Read(&coo->num_cols)) << "Invalid COOMatrix file format";
  CHECK(strm_->Read(&coo->row_sorted)) << "Invalid COOMatrix file format";
  CHECK(strm_->Read(&coo->col_sorted)) << "Invalid COOMatrix file format";
  return true;
}

/*!
 * \brief Rebuild a CSRMatrix written by CopyToSharedMem<CSRMatrix>.
 *
 * Reads the indptr, indices and data arrays (in that order), then num_rows,
 * num_cols and sorted from the metadata stream.
 *
 * \param csr  Output matrix whose fields are overwritten.
 * \param name Prefix used to derive the per-array segment names.
 * \return Always true; a malformed or truncated stream aborts via CHECK
 *         instead of silently leaving fields uninitialized.
 */
template <>
bool SharedMemManager::CreateFromSharedMem<CSRMatrix>(CSRMatrix *csr,
                                                      std::string name) {
  CHECK(CreateFromSharedMem(&csr->indptr, name + "_indptr"))
    << "Invalid CSRMatrix file format";
  CHECK(CreateFromSharedMem(&csr->indices, name + "_indices"))
    << "Invalid CSRMatrix file format";
  CHECK(CreateFromSharedMem(&csr->data, name + "_data"))
    << "Invalid CSRMatrix file format";
  CHECK(strm_->Read(&csr->num_rows)) << "Invalid CSRMatrix file format";
  CHECK(strm_->Read(&csr->num_cols)) << "Invalid CSRMatrix file format";
  CHECK(strm_->Read(&csr->sorted)) << "Invalid CSRMatrix file format";
  return true;
}

}  // namespace dgl