Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
913e3249
Unverified
Commit
913e3249
authored
Dec 27, 2019
by
xiang song(charlie.song)
Committed by
GitHub
Dec 27, 2019
Browse files
Add device check for sampler input (#1145)
current samplers only support working on CPU
parent
d57ff78d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
56 additions
and
0 deletions
+56
-0
src/graph/sampler.cc
src/graph/sampler.cc
+36
-0
src/graph/sampler/metapath.cc
src/graph/sampler/metapath.cc
+6
-0
src/graph/sampler/randomwalk.cc
src/graph/sampler/randomwalk.cc
+12
-0
tests/compute/test_sampler.py
tests/compute/test_sampler.py
+2
-0
No files found.
src/graph/sampler.cc
View file @
913e3249
...
...
@@ -877,6 +877,10 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
auto
gptr
=
std
::
dynamic_pointer_cast
<
ImmutableGraph
>
(
g
.
sptr
());
CHECK
(
gptr
)
<<
"sampling isn't implemented in mutable graph"
;
CHECK
(
aten
::
IsValidIdArray
(
seed_nodes
));
CHECK_EQ
(
seed_nodes
->
ctx
.
device_type
,
kDLCPU
)
<<
"UniformSampler only support CPU sampling"
;
std
::
vector
<
NodeFlow
>
nflows
=
NeighborSamplingImpl
<
float
>
(
gptr
,
seed_nodes
,
batch_start_id
,
batch_size
,
max_num_workers
,
expand_factor
,
num_hops
,
neigh_type
,
add_self_loop
,
nullptr
);
...
...
@@ -901,12 +905,18 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
auto
gptr
=
std
::
dynamic_pointer_cast
<
ImmutableGraph
>
(
g
.
sptr
());
CHECK
(
gptr
)
<<
"sampling isn't implemented in mutable graph"
;
CHECK
(
aten
::
IsValidIdArray
(
seed_nodes
));
CHECK_EQ
(
seed_nodes
->
ctx
.
device_type
,
kDLCPU
)
<<
"NeighborSampler only support CPU sampling"
;
std
::
vector
<
NodeFlow
>
nflows
;
CHECK
(
probability
->
dtype
.
code
==
kDLFloat
)
<<
"transition probability must be float"
;
CHECK
(
probability
->
ndim
==
1
)
<<
"transition probability must be a 1-dimensional vector"
;
CHECK_EQ
(
probability
->
ctx
.
device_type
,
kDLCPU
)
<<
"NeighborSampling only support CPU sampling"
;
ATEN_FLOAT_TYPE_SWITCH
(
probability
->
dtype
,
...
...
@@ -947,6 +957,13 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
auto
gptr
=
std
::
dynamic_pointer_cast
<
ImmutableGraph
>
(
g
.
sptr
());
CHECK
(
gptr
)
<<
"sampling isn't implemented in mutable graph"
;
CHECK
(
aten
::
IsValidIdArray
(
seed_nodes
));
CHECK_EQ
(
seed_nodes
->
ctx
.
device_type
,
kDLCPU
)
<<
"LayerSampler only support CPU sampling"
;
CHECK
(
aten
::
IsValidIdArray
(
layer_sizes
));
CHECK_EQ
(
layer_sizes
->
ctx
.
device_type
,
kDLCPU
)
<<
"LayerSampler only support CPU sampling"
;
const
dgl_id_t
*
seed_nodes_data
=
static_cast
<
dgl_id_t
*>
(
seed_nodes
->
data
);
const
int64_t
num_seeds
=
seed_nodes
->
shape
[
0
];
const
int64_t
num_workers
=
std
::
min
(
max_num_workers
,
...
...
@@ -1570,6 +1587,14 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler")
auto
gptr
=
std
::
dynamic_pointer_cast
<
ImmutableGraph
>
(
g
.
sptr
());
CHECK
(
gptr
)
<<
"sampling isn't implemented in mutable graph"
;
CHECK
(
aten
::
IsValidIdArray
(
seed_edges
));
CHECK_EQ
(
seed_edges
->
ctx
.
device_type
,
kDLCPU
)
<<
"UniformEdgeSampler only support CPU sampling"
;
if
(
relations
->
shape
[
0
]
>
0
)
{
CHECK
(
aten
::
IsValidIdArray
(
relations
));
CHECK_EQ
(
relations
->
ctx
.
device_type
,
kDLCPU
)
<<
"WeightedEdgeSampler only support CPU sampling"
;
}
BuildCoo
(
*
gptr
);
auto
o
=
std
::
make_shared
<
UniformEdgeSamplerObject
>
(
gptr
,
...
...
@@ -1842,11 +1867,22 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
auto
gptr
=
std
::
dynamic_pointer_cast
<
ImmutableGraph
>
(
g
.
sptr
());
CHECK
(
gptr
)
<<
"sampling isn't implemented in mutable graph"
;
CHECK
(
aten
::
IsValidIdArray
(
seed_edges
));
CHECK_EQ
(
seed_edges
->
ctx
.
device_type
,
kDLCPU
)
<<
"WeightedEdgeSampler only support CPU sampling"
;
CHECK
(
edge_weight
->
dtype
.
code
==
kDLFloat
)
<<
"edge_weight should be FloatType"
;
CHECK
(
edge_weight
->
dtype
.
bits
==
32
)
<<
"WeightedEdgeSampler only support float weight"
;
CHECK_EQ
(
edge_weight
->
ctx
.
device_type
,
kDLCPU
)
<<
"WeightedEdgeSampler only support CPU sampling"
;
if
(
node_weight
->
shape
[
0
]
>
0
)
{
CHECK
(
node_weight
->
dtype
.
code
==
kDLFloat
)
<<
"node_weight should be FloatType"
;
CHECK
(
node_weight
->
dtype
.
bits
==
32
)
<<
"WeightedEdgeSampler only support float weight"
;
CHECK_EQ
(
node_weight
->
ctx
.
device_type
,
kDLCPU
)
<<
"WeightedEdgeSampler only support CPU sampling"
;
}
if
(
relations
->
shape
[
0
]
>
0
)
{
CHECK
(
aten
::
IsValidIdArray
(
relations
));
CHECK_EQ
(
relations
->
ctx
.
device_type
,
kDLCPU
)
<<
"WeightedEdgeSampler only support CPU sampling"
;
}
BuildCoo
(
*
gptr
);
...
...
src/graph/sampler/metapath.cc
View file @
913e3249
...
...
@@ -83,6 +83,12 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLMetapathRandomWalk")
const
IdArray
seeds
=
args
[
2
];
int
num_traces
=
args
[
3
];
CHECK
(
aten
::
IsValidIdArray
(
etypes
));
CHECK_EQ
(
etypes
->
ctx
.
device_type
,
kDLCPU
)
<<
"MetapathRandomWalk only support CPU sampling"
;
CHECK
(
aten
::
IsValidIdArray
(
seeds
));
CHECK_EQ
(
seeds
->
ctx
.
device_type
,
kDLCPU
)
<<
"MetapathRandomWalk only support CPU sampling"
;
const
auto
tl
=
MetapathRandomWalk
(
hg
.
sptr
(),
etypes
,
seeds
,
num_traces
);
*
rv
=
RandomWalkTracesRef
(
tl
);
});
...
...
src/graph/sampler/randomwalk.cc
View file @
913e3249
...
...
@@ -213,6 +213,10 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLRandomWalk")
const
int
num_traces
=
args
[
2
];
const
int
num_hops
=
args
[
3
];
CHECK
(
aten
::
IsValidIdArray
(
seeds
));
CHECK_EQ
(
seeds
->
ctx
.
device_type
,
kDLCPU
)
<<
"RandomWalk only support CPU sampling"
;
*
rv
=
RandomWalk
(
g
.
sptr
().
get
(),
seeds
,
num_traces
,
num_hops
);
});
...
...
@@ -225,6 +229,10 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLRandomWalkWithRestart")
const
uint64_t
max_visit_counts
=
args
[
4
];
const
uint64_t
max_frequent_visited_nodes
=
args
[
5
];
CHECK
(
aten
::
IsValidIdArray
(
seeds
));
CHECK_EQ
(
seeds
->
ctx
.
device_type
,
kDLCPU
)
<<
"RandomWalkWithRestart only support CPU sampling"
;
*
rv
=
RandomWalkTracesRef
(
RandomWalkWithRestart
(
g
.
sptr
().
get
(),
seeds
,
restart_prob
,
visit_threshold_per_seed
,
max_visit_counts
,
max_frequent_visited_nodes
));
...
...
@@ -239,6 +247,10 @@ DGL_REGISTER_GLOBAL("sampler.randomwalk._CAPI_DGLBipartiteSingleSidedRandomWalkW
const
uint64_t
max_visit_counts
=
args
[
4
];
const
uint64_t
max_frequent_visited_nodes
=
args
[
5
];
CHECK
(
aten
::
IsValidIdArray
(
seeds
));
CHECK_EQ
(
seeds
->
ctx
.
device_type
,
kDLCPU
)
<<
"BipartiteSingleSidedRandomWalkWithRestart only support CPU sampling"
;
*
rv
=
RandomWalkTracesRef
(
BipartiteSingleSidedRandomWalkWithRestart
(
g
.
sptr
().
get
(),
seeds
,
restart_prob
,
visit_threshold_per_seed
,
...
...
tests/compute/test_sampler.py
View file @
913e3249
...
...
@@ -591,6 +591,7 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
for
pos_edges
,
neg_edges
in
EdgeSampler
(
g
,
batch_size
,
replacement
=
True
,
edge_weight
=
edge_weight
,
shuffle
=
True
,
negative_mode
=
mode
,
neg_sample_size
=
neg_size
,
exclude_positive
=
False
,
...
...
@@ -630,6 +631,7 @@ def check_weighted_negative_sampler(mode, exclude_positive, neg_size):
replacement
=
True
,
edge_weight
=
edge_weight
,
node_weight
=
node_weight
,
shuffle
=
True
,
negative_mode
=
mode
,
neg_sample_size
=
neg_size
,
exclude_positive
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment