/**
 *  Copyright (c) 2019 by Contributors
 * @file api/api_container.cc
 * @brief Runtime container APIs. (reference: tvm/src/api/api_lang.cc)
 */
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/registry.h>

namespace dgl {
namespace runtime {

14
15
16
17
18
19
20
DGL_REGISTER_GLOBAL("_List").set_body([](DGLArgs args, DGLRetValue* rv) {
  auto ret_obj = std::make_shared<runtime::ListObject>();
  for (int i = 0; i < args.size(); ++i) {
    ret_obj->data.push_back(args[i].obj_sptr());
  }
  *rv = ret_obj;
});
21

22
23
24
25
26
27
28
29
// "_ListGetItem": returns element args[1] of the ListObject passed as args[0].
// The index must satisfy 0 <= index < size; both bounds are checked explicitly
// with signed arithmetic instead of relying on an implicit signed-to-unsigned
// conversion when comparing against data.size().
DGL_REGISTER_GLOBAL("_ListGetItem").set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& sptr = args[0].obj_sptr();
  CHECK(sptr->is_type<ListObject>());
  auto* o = static_cast<const ListObject*>(sptr.get());
  const int64_t i = args[1];
  const int64_t size = static_cast<int64_t>(o->data.size());
  CHECK_GE(i, 0) << "list out of bound";
  CHECK_LT(i, size) << "list out of bound";
  *rv = o->data[i];
});

// "_ListSize": returns the number of elements of the ListObject passed as
// args[0], as an int64_t.
DGL_REGISTER_GLOBAL("_ListSize").set_body([](DGLArgs args, DGLRetValue* rv) {
  const auto& handle = args[0].obj_sptr();
  CHECK(handle->is_type<ListObject>());
  const auto* list = static_cast<const ListObject*>(handle.get());
  const int64_t num_elems = static_cast<int64_t>(list->data.size());
  *rv = num_elems;
});

// "_Map": builds a map from alternating key/value arguments (k0, v0, k1, v1, ...).
//  - If the first key is a string, every key must be a string and the result
//    is a StrMapObject keyed by std::string.
//  - Otherwise (including the zero-argument case) every key must be an object
//    handle and the result is a MapObject keyed by the object pointer.
// Every value must be an object handle.
// NOTE(review): entries are inserted with emplace(); for a standard associative
// ContainerType that means a repeated key keeps its first value — confirm this
// is the intended semantics.
DGL_REGISTER_GLOBAL("_Map").set_body([](DGLArgs args, DGLRetValue* rv) {
  CHECK_EQ(args.size() % 2, 0)
      << "_Map expects an even number of arguments (key/value pairs)";
  if (args.size() != 0 && args[0].type_code() == kStr) {
    // StrMap
    auto obj = std::make_shared<StrMapObject>();
    for (int i = 0; i < args.size(); i += 2) {
      CHECK(args[i].type_code() == kStr) << "The key of the map must be string";
      CHECK(args[i + 1].type_code() == kObjectHandle)
          << "The value of the map must be an object type";
      // Construct the entry in place instead of via a temporary std::pair.
      obj->data.emplace(args[i].operator std::string(), args[i + 1].obj_sptr());
    }
    *rv = obj;
  } else {
    // object container
    auto obj = std::make_shared<MapObject>();
    for (int i = 0; i < args.size(); i += 2) {
      CHECK(args[i].type_code() == kObjectHandle)
          << "The key of the map must be an object type";
      CHECK(args[i + 1].type_code() == kObjectHandle)
          << "The value of the map must be an object type";
      obj->data.emplace(args[i].obj_sptr(), args[i + 1].obj_sptr());
    }
    *rv = obj;
  }
});

// "_EmptyStrMap": returns a new, empty StrMapObject. Its arguments are ignored.
DGL_REGISTER_GLOBAL("_EmptyStrMap").set_body([](DGLArgs /* args */, DGLRetValue* rv) {
  // A freshly constructed StrMapObject already holds a default-constructed
  // (empty) container, so there is no need to build one and move it in.
  auto obj = std::make_shared<StrMapObject>();
  *rv = obj;
});

// "_MapSize": returns the number of entries (as int64_t) of args[0], which must
// be either a MapObject or a StrMapObject (anything else fails the CHECK).
DGL_REGISTER_GLOBAL("_MapSize").set_body([](DGLArgs args, DGLRetValue* rv) {
  const auto& handle = args[0].obj_sptr();
  int64_t num_entries = 0;
  if (handle->is_type<MapObject>()) {
    num_entries = static_cast<int64_t>(
        static_cast<const MapObject*>(handle.get())->data.size());
  } else {
    CHECK(handle->is_type<StrMapObject>());
    num_entries = static_cast<int64_t>(
        static_cast<const StrMapObject*>(handle.get())->data.size());
  }
  *rv = num_entries;
});

// "_MapGetItem": looks up key args[1] in the map args[0] and returns the value.
// For a MapObject the key is read as an object handle (obj_sptr()); for a
// StrMapObject it is read as a std::string. A missing key fails a CHECK.
DGL_REGISTER_GLOBAL("_MapGetItem").set_body([](DGLArgs args, DGLRetValue* rv) {
  const auto& handle = args[0].obj_sptr();
  if (!handle->is_type<MapObject>()) {
    CHECK(handle->is_type<StrMapObject>());
    const auto& entries = static_cast<const StrMapObject*>(handle.get())->data;
    const auto found = entries.find(args[1].operator std::string());
    CHECK(found != entries.end()) << "cannot find the key in the map";
    *rv = found->second;
    return;
  }
  const auto& entries = static_cast<const MapObject*>(handle.get())->data;
  const auto found = entries.find(args[1].obj_sptr());
  CHECK(found != entries.end()) << "cannot find the key in the map";
  *rv = found->second;
});

// "_MapItems": flattens the map args[0] into a ListObject of alternating
// entries [k0, v0, k1, v1, ...] in the container's iteration order.
// args[0] must be a MapObject or a StrMapObject; for a StrMapObject each
// std::string key is wrapped with MakeValue so it can be stored in the list.
DGL_REGISTER_GLOBAL("_MapItems").set_body([](DGLArgs args, DGLRetValue* rv) {
  auto& sptr = args[0].obj_sptr();
  auto rkvs = std::make_shared<ListObject>();
  if (sptr->is_type<MapObject>()) {
    auto* o = static_cast<const MapObject*>(sptr.get());
    // Two list slots per map entry; reserve once instead of regrowing.
    rkvs->data.reserve(2 * o->data.size());
    for (const auto& kv : o->data) {
      rkvs->data.push_back(kv.first);
      rkvs->data.push_back(kv.second);
    }
  } else {
    CHECK(sptr->is_type<StrMapObject>());
    auto* o = static_cast<const StrMapObject*>(sptr.get());
    rkvs->data.reserve(2 * o->data.size());
    for (const auto& kv : o->data) {
      rkvs->data.push_back(MakeValue(kv.first));
      rkvs->data.push_back(kv.second);
    }
  }
  *rv = rkvs;
});

// "_MapCount": returns, as int64_t, the container's count() for key args[1] in
// the map args[0] (0 or 1 for a unique-key container). For a MapObject the key
// is read as an object handle; for a StrMapObject it is read as a std::string.
DGL_REGISTER_GLOBAL("_MapCount").set_body([](DGLArgs args, DGLRetValue* rv) {
  const auto& handle = args[0].obj_sptr();
  int64_t hits = 0;
  if (handle->is_type<MapObject>()) {
    const auto& entries = static_cast<const MapObject*>(handle.get())->data;
    hits = static_cast<int64_t>(entries.count(args[1].obj_sptr()));
  } else {
    CHECK(handle->is_type<StrMapObject>());
    const auto& entries = static_cast<const StrMapObject*>(handle.get())->data;
    hits = static_cast<int64_t>(entries.count(args[1].operator std::string()));
  }
  *rv = hits;
});

// "_Value": wraps the first argument with MakeValue and returns the result.
DGL_REGISTER_GLOBAL("_Value").set_body([](DGLArgs args, DGLRetValue* rv) {
  auto wrapped = MakeValue(args[0]);
  *rv = wrapped;
});

// "_ValueGet": unwraps the ValueObject passed as args[0] and returns its
// stored `data` member.
DGL_REGISTER_GLOBAL("_ValueGet").set_body([](DGLArgs args, DGLRetValue* rv) {
  const auto& handle = args[0].obj_sptr();
  CHECK(handle->is_type<ValueObject>());
  const auto& value = static_cast<const ValueObject*>(handle.get())->data;
  *rv = value;
});

}  // namespace runtime
}  // namespace dgl