"vscode:/vscode.git/clone" did not exist on "ca4a46eff11adc0881351db3e45378d23b521b92"
Unverified Commit 913e3249 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Add device check for sampler input (#1145)

current samplers only support working on CPU
parent d57ff78d
......@@ -877,6 +877,10 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_nodes));
CHECK_EQ(seed_nodes->ctx.device_type, kDLCPU)
<< "UniformSampler only support CPU sampling";
std::vector<NodeFlow> nflows = NeighborSamplingImpl<float>(
gptr, seed_nodes, batch_start_id, batch_size, max_num_workers,
expand_factor, num_hops, neigh_type, add_self_loop, nullptr);
......@@ -901,12 +905,18 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_nodes));
CHECK_EQ(seed_nodes->ctx.device_type, kDLCPU)
<< "NeighborSampler only support CPU sampling";
std::vector<NodeFlow> nflows;
CHECK(probability->dtype.code == kDLFloat)
<< "transition probability must be float";
CHECK(probability->ndim == 1)
<< "transition probability must be a 1-dimensional vector";
CHECK_EQ(probability->ctx.device_type, kDLCPU)
<< "NeighborSampling only support CPU sampling";
ATEN_FLOAT_TYPE_SWITCH(
probability->dtype,
......@@ -947,6 +957,13 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_nodes));
CHECK_EQ(seed_nodes->ctx.device_type, kDLCPU)
<< "LayerSampler only support CPU sampling";
CHECK(aten::IsValidIdArray(layer_sizes));
CHECK_EQ(layer_sizes->ctx.device_type, kDLCPU)
<< "LayerSampler only support CPU sampling";
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers,
......@@ -1570,6 +1587,14 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler")
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_edges));
CHECK_EQ(seed_edges->ctx.device_type, kDLCPU)
<< "UniformEdgeSampler only support CPU sampling";
if (relations->shape[0] > 0) {
CHECK(aten::IsValidIdArray(relations));
CHECK_EQ(relations->ctx.device_type, kDLCPU)
<< "WeightedEdgeSampler only support CPU sampling";
}
BuildCoo(*gptr);
auto o = std::make_shared<UniformEdgeSamplerObject>(gptr,
......@@ -1842,11 +1867,22 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(aten::IsValidIdArray(seed_edges));
CHECK_EQ(seed_edges->ctx.device_type, kDLCPU)
<< "WeightedEdgeSampler only support CPU sampling";
CHECK(edge_weight->dtype.code == kDLFloat) << "edge_weight should be FloatType";
CHECK(edge_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight";
CHECK_EQ(edge_weight->ctx.device_type, kDLCPU)
<< "WeightedEdgeSampler only support CPU sampling";
if (node_weight->shape[0] > 0) {
CHECK(node_weight->dtype.code == kDLFloat) << "node_weight should be FloatType";
CHECK(node_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight";
CHECK_EQ(node_weight->ctx.device_type, kDLCPU)
<< "WeightedEdgeSampler only support CPU sampling";
}
if (relations->shape[0] > 0) {
CHECK(aten::IsValidIdArray(relations));
CHECK_EQ(relations->ctx.device_type, kDLCPU)
<< "WeightedEdgeSampler only support CPU sampling";
}
BuildCoo(*gptr);
......
......@@ -83,6 +83,12 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk")
const IdArray seeds = args[2];
int num_traces = args[3];
CHECK(aten::IsValidIdArray(etypes));
CHECK_EQ(etypes->ctx.device_type, kDLCPU)
<< "MetapathRandomWalk only support CPU sampling";
CHECK(aten::IsValidIdArray(seeds));
CHECK_EQ(seeds->ctx.device_type, kDLCPU)
<< "MetapathRandomWalk only support CPU sampling";
const auto tl = MetapathRandomWalk(hg.sptr(), etypes, seeds, num_traces);
*rv = RandomWalkTracesRef(tl);
});
......
......@@ -213,6 +213,10 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLRandomWalk")
const int num_traces = args[2];
const int num_hops = args[3];
CHECK(aten::IsValidIdArray(seeds));
CHECK_EQ(seeds->ctx.device_type, kDLCPU)
<< "RandomWalk only support CPU sampling";
*rv = RandomWalk(g.sptr().get(), seeds, num_traces, num_hops);
});
......@@ -225,6 +229,10 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLRandomWalkWithRestart")
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
CHECK(aten::IsValidIdArray(seeds));
CHECK_EQ(seeds->ctx.device_type, kDLCPU)
<< "RandomWalkWithRestart only support CPU sampling";
*rv = RandomWalkTracesRef(
RandomWalkWithRestart(g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
max_visit_counts, max_frequent_visited_nodes));
......@@ -239,6 +247,10 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkW
const uint64_t max_visit_counts = args[4];
const uint64_t max_frequent_visited_nodes = args[5];
CHECK(aten::IsValidIdArray(seeds));
CHECK_EQ(seeds->ctx.device_type, kDLCPU)
<< "BipartiteSingleSidedRandomWalkWithRestart only support CPU sampling";
*rv = RandomWalkTracesRef(
BipartiteSingleSidedRandomWalkWithRestart(
g.sptr().get(), seeds, restart_prob, visit_threshold_per_seed,
......
......@@ -591,6 +591,7 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
for pos_edges, neg_edges in EdgeSampler(g, batch_size,
replacement=True,
edge_weight=edge_weight,
shuffle=True,
negative_mode=mode,
neg_sample_size=neg_size,
exclude_positive=False,
......@@ -630,6 +631,7 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
replacement=True,
edge_weight=edge_weight,
node_weight=node_weight,
shuffle=True,
negative_mode=mode,
neg_sample_size=neg_size,
exclude_positive=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment