Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a33fafb7
Unverified
Commit
a33fafb7
authored
Jul 12, 2023
by
peizhou001
Committed by
GitHub
Jul 12, 2023
Browse files
[Graphbolt] Change probs to name of attribute (#5968)
parent
e6e54304
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
49 deletions
+45
-49
graphbolt/include/graphbolt/csc_sampling_graph.h
graphbolt/include/graphbolt/csc_sampling_graph.h
+6
-5
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+11
-8
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+13
-8
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+15
-28
No files found.
graphbolt/include/graphbolt/csc_sampling_graph.h
View file @
a33fafb7
...
@@ -145,10 +145,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -145,10 +145,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* Otherwise, each value can be selected only once.
* Otherwise, each value can be selected only once.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* typically used when edge features are required.
* @param probs_or_mask Optional tensor containing the (unnormalized)
* @param probs_name An optional string specifying the name of an edge
* probabilities or boolean mask associated with each neighboring edge of a
* attribute. This attribute tensor should contain (unnormalized)
* node. It must be a 1D floating-point or boolean tensor with the number of
* probabilities corresponding to each neighboring edge of a node. It must be
* elements equal to the number of edges.
* a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges.
*
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
* sampled graph's information.
...
@@ -156,7 +157,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -156,7 +157,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighbors
(
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
return_eids
,
bool
replace
,
bool
return_eids
,
torch
::
optional
<
torch
::
Tensor
>
probs_
or_mask
)
const
;
torch
::
optional
<
std
::
string
>
probs_
name
)
const
;
/**
/**
* @brief Sample negative edges by randomly choosing negative
* @brief Sample negative edges by randomly choosing negative
...
...
graphbolt/src/csc_sampling_graph.cc
View file @
a33fafb7
...
@@ -132,16 +132,19 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
...
@@ -132,16 +132,19 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
c10
::
intrusive_ptr
<
SampledSubgraph
>
CSCSamplingGraph
::
SampleNeighbors
(
c10
::
intrusive_ptr
<
SampledSubgraph
>
CSCSamplingGraph
::
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
return_eids
,
bool
replace
,
bool
return_eids
,
torch
::
optional
<
torch
::
Tensor
>
probs_
or_mask
)
const
{
torch
::
optional
<
std
::
string
>
probs_
name
)
const
{
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
// Note probs will be passed as input for 'torch.multinomial' in deeper stack,
torch
::
optional
<
torch
::
Tensor
>
probs_or_mask
=
torch
::
nullopt
;
// which doesn't support 'torch.half' and 'torch.bool' data types. To avoid
if
(
probs_name
.
has_value
()
&&
!
probs_name
.
value
().
empty
())
{
// crashes, convert 'probs_or_mask' to 'float32' data type.
probs_or_mask
=
edge_attributes_
.
value
().
at
(
probs_name
.
value
());
if
(
probs_or_mask
.
has_value
()
&&
// Note probs will be passed as input for 'torch.multinomial' in deeper
(
probs_or_mask
.
value
().
dtype
()
==
torch
::
kBool
||
// stack, which doesn't support 'torch.half' and 'torch.bool' data types. To
probs_or_mask
.
value
().
dtype
()
==
torch
::
kFloat16
))
{
// avoid crashes, convert 'probs_or_mask' to 'float32' data type.
if
(
probs_or_mask
.
value
().
dtype
()
==
torch
::
kBool
||
probs_or_mask
.
value
().
dtype
()
==
torch
::
kFloat16
)
{
probs_or_mask
=
probs_or_mask
.
value
().
to
(
torch
::
kFloat32
);
probs_or_mask
=
probs_or_mask
.
value
().
to
(
torch
::
kFloat32
);
}
}
}
// If true, perform sampling for each edge type of each node, otherwise just
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
// sample once for each node with no regard of edge types.
bool
consider_etype
=
(
fanouts
.
size
()
>
1
);
bool
consider_etype
=
(
fanouts
.
size
()
>
1
);
...
...
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
a33fafb7
...
@@ -219,7 +219,7 @@ class CSCSamplingGraph:
...
@@ -219,7 +219,7 @@ class CSCSamplingGraph:
fanouts
:
torch
.
Tensor
,
fanouts
:
torch
.
Tensor
,
replace
:
bool
=
False
,
replace
:
bool
=
False
,
return_eids
:
bool
=
False
,
return_eids
:
bool
=
False
,
probs_
or_mask
:
Optional
[
t
orch
.
Tenso
r
]
=
None
,
probs_
name
:
Optional
[
s
tr
]
=
None
,
)
->
torch
.
ScriptObject
:
)
->
torch
.
ScriptObject
:
"""Sample neighboring edges of the given nodes and return the induced
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
subgraph.
...
@@ -252,11 +252,12 @@ class CSCSamplingGraph:
...
@@ -252,11 +252,12 @@ class CSCSamplingGraph:
Boolean indicating whether the edge IDs of sampled edges,
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
represented as a 1D tensor, should be returned. This is
typically used when edge features are required.
typically used when edge features are required.
probs_or_mask: torch.Tensor, optional
probs_name: str, optional
Optional tensor containing the (unnormalized) probabilities
An optional string specifying the name of an edge attribute. This
associated with each neighboring edge of a node. It must be a 1D
attribute tensor should contain (unnormalized) probabilities
floating-point or boolean tensor with the number of elements equal
corresponding to each neighboring edge of a node. It must be a 1D
to the number of edges.
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
Returns
Returns
-------
-------
torch.classes.graphbolt.SampledSubgraph
torch.classes.graphbolt.SampledSubgraph
...
@@ -302,7 +303,11 @@ class CSCSamplingGraph:
...
@@ -302,7 +303,11 @@ class CSCSamplingGraph:
(
fanouts
>=
0
)
|
(
fanouts
==
-
1
)
(
fanouts
>=
0
)
|
(
fanouts
==
-
1
)
),
"Fanouts should consist of values that are either -1 or
\
),
"Fanouts should consist of values that are either -1 or
\
greater than or equal to 0."
greater than or equal to 0."
if
probs_or_mask
is
not
None
:
if
probs_name
:
assert
(
probs_name
in
self
.
edge_attributes
),
f
"Unknown edge attribute '
{
probs_name
}
'."
probs_or_mask
=
self
.
edge_attributes
[
probs_name
]
assert
probs_or_mask
.
dim
()
==
1
,
"Probs should be 1-D tensor."
assert
probs_or_mask
.
dim
()
==
1
,
"Probs should be 1-D tensor."
assert
(
assert
(
probs_or_mask
.
size
(
0
)
==
self
.
num_edges
probs_or_mask
.
size
(
0
)
==
self
.
num_edges
...
@@ -316,7 +321,7 @@ class CSCSamplingGraph:
...
@@ -316,7 +321,7 @@ class CSCSamplingGraph:
torch
.
float64
,
torch
.
float64
,
],
"Probs should have a floating-point or boolean data type."
],
"Probs should have a floating-point or boolean data type."
return
self
.
_c_csc_graph
.
sample_neighbors
(
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
fanouts
.
tolist
(),
replace
,
return_eids
,
probs_
or_mask
nodes
,
fanouts
.
tolist
(),
replace
,
return_eids
,
probs_
name
)
)
def
sample_negative_edges_uniform
(
def
sample_negative_edges_uniform
(
...
...
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
a33fafb7
...
@@ -3,7 +3,6 @@ import tempfile
...
@@ -3,7 +3,6 @@ import tempfile
import
unittest
import
unittest
import
backend
as
F
import
backend
as
F
import
dgl
import
dgl
import
dgl.graphbolt
as
gb
import
dgl.graphbolt
as
gb
...
@@ -508,29 +507,8 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):
...
@@ -508,29 +507,8 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):
reason
=
"Graph is CPU only at present."
,
reason
=
"Graph is CPU only at present."
,
)
)
@
pytest
.
mark
.
parametrize
(
"replace"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"replace"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"probs_name"
,
[
"weight"
,
"mask"
])
"probs_or_mask"
,
def
test_sample_neighbors_probs
(
replace
,
probs_name
):
[
torch
.
tensor
([
2.5
,
0
,
8.4
,
0
,
0.4
,
1.2
,
2.5
,
0
,
8.4
,
0.5
,
0.4
,
1.2
]),
torch
.
tensor
(
[
True
,
False
,
True
,
False
,
True
,
True
,
True
,
False
,
True
,
True
,
True
,
True
,
]
),
],
)
def
test_sample_neighbors_probs
(
replace
,
probs_or_mask
):
"""Original graph in COO:
"""Original graph in COO:
1 0 1 0 1
1 0 1 0 1
1 0 1 1 0
1 0 1 1 0
...
@@ -546,8 +524,15 @@ def test_sample_neighbors_probs(replace, probs_or_mask):
...
@@ -546,8 +524,15 @@ def test_sample_neighbors_probs(replace, probs_or_mask):
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
assert
indptr
[
-
1
]
==
len
(
indices
)
edge_attributes
=
{
"weight"
:
torch
.
FloatTensor
(
[
2.5
,
0
,
8.4
,
0
,
0.4
,
1.2
,
2.5
,
0
,
8.4
,
0.5
,
0.4
,
1.2
]
),
"mask"
:
torch
.
BoolTensor
([
1
,
0
,
1
,
0
,
1
,
1
,
1
,
0
,
1
,
1
,
1
,
1
]),
}
# Construct CSCSamplingGraph.
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
)
graph
=
gb
.
from_csc
(
indptr
,
indices
,
edge_attributes
=
edge_attributes
)
# Generate subgraph via sample neighbors.
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
...
@@ -555,7 +540,7 @@ def test_sample_neighbors_probs(replace, probs_or_mask):
...
@@ -555,7 +540,7 @@ def test_sample_neighbors_probs(replace, probs_or_mask):
nodes
,
nodes
,
fanouts
=
torch
.
tensor
([
2
]),
fanouts
=
torch
.
tensor
([
2
]),
replace
=
replace
,
replace
=
replace
,
probs_
or_mask
=
probs_or_mask
,
probs_
name
=
probs_name
,
)
)
# Verify in subgraph.
# Verify in subgraph.
...
@@ -587,8 +572,10 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask):
...
@@ -587,8 +572,10 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask):
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
assert
indptr
[
-
1
]
==
len
(
indices
)
edge_attributes
=
{
"probs_or_mask"
:
probs_or_mask
}
# Construct CSCSamplingGraph.
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
)
graph
=
gb
.
from_csc
(
indptr
,
indices
,
edge_attributes
=
edge_attributes
)
# Generate subgraph via sample neighbors.
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
...
@@ -596,7 +583,7 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask):
...
@@ -596,7 +583,7 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask):
nodes
,
nodes
,
fanouts
=
torch
.
tensor
([
5
]),
fanouts
=
torch
.
tensor
([
5
]),
replace
=
replace
,
replace
=
replace
,
probs_
or_mask
=
probs_or_mask
,
probs_
name
=
"
probs_or_mask
"
,
)
)
# Verify in subgraph.
# Verify in subgraph.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment